[ML] explicitly disallow partial results in datafeed extractors (#55537) (#55585)

Instead of doing our own checks against REST status, shard counts, and shard failures, this commit changes all our extractor search requests to set `.setAllowPartialSearchResults(false)`.

- Scrolls are automatically cleared when a search failure occurs with `.setAllowPartialSearchResults(false)` set.
- Error handling in the extractor code is simplified.

closes https://github.com/elastic/elasticsearch/issues/40793
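
For context, a minimal self-contained sketch of the pattern the extractors now rely on (the PartialResultsSketch class, strictSearch helper, and index name are illustrative, not code from this PR): with partial results disallowed, a shard failure surfaces as a SearchPhaseExecutionException rather than a SearchResponse that has to be inspected for failed shards.

import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.search.builder.SearchSourceBuilder;

class PartialResultsSketch {
    // Illustrative only: a failed shard fails the whole search, so callers
    // handle one exception path instead of checking the REST status, the
    // shard failure array, and successful-vs-total shard counts.
    static SearchResponse strictSearch(Client client, String index, SearchSourceBuilder source) {
        return new SearchRequestBuilder(client, SearchAction.INSTANCE)
            .setIndices(index)
            .setAllowPartialSearchResults(false) // throw instead of returning partial results
            .setSource(source)
            .get();
    }
}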
Benjamin Trent 2020-04-22 09:07:44 -04:00 committed by GitHub
parent 810caf5ffe
commit 7c81cd7833
10 changed files with 111 additions and 217 deletions

ExtractorUtils.java

@@ -5,17 +5,12 @@
*/
package org.elasticsearch.xpack.core.ml.datafeed.extractor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.Rounding;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
@@ -23,9 +18,7 @@ import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregati
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.time.ZoneOffset;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.TimeUnit;
@@ -34,7 +27,6 @@ import java.util.concurrent.TimeUnit;
*/
public final class ExtractorUtils {
private static final Logger LOGGER = LogManager.getLogger(ExtractorUtils.class);
private static final String EPOCH_MILLIS = "epoch_millis";
private ExtractorUtils() {}
@@ -47,25 +39,6 @@ public final class ExtractorUtils {
return new BoolQueryBuilder().filter(userQuery).filter(timeQuery);
}
/**
* Checks that a {@link SearchResponse} has an OK status code and no shard failures
*/
public static void checkSearchWasSuccessful(String jobId, SearchResponse searchResponse) throws IOException {
if (searchResponse.status() != RestStatus.OK) {
throw new IOException("[" + jobId + "] Search request returned status code: " + searchResponse.status()
+ ". Response was:\n" + searchResponse.toString());
}
ShardSearchFailure[] shardFailures = searchResponse.getShardFailures();
if (shardFailures != null && shardFailures.length > 0) {
LOGGER.error("[{}] Search request returned shard failures: {}", jobId, Arrays.toString(shardFailures));
throw new IOException(ExceptionsHelper.shardFailuresToErrorMsg(jobId, shardFailures));
}
int unavailableShards = searchResponse.getTotalShards() - searchResponse.getSuccessfulShards();
if (unavailableShards > 0) {
throw new IOException("[" + jobId + "] Search request encountered [" + unavailableShards + "] unavailable shards");
}
}
/**
* Find the (date) histogram in {@code aggFactory} and extract its interval.
* Throws if there is no (date) histogram or if the histogram has sibling

AbstractAggregationDataExtractor.java

@@ -107,12 +107,13 @@ abstract class AbstractAggregationDataExtractor<T extends ActionRequestBuilder<S
return Optional.ofNullable(processNextBatch());
}
private Aggregations search() throws IOException {
private Aggregations search() {
LOGGER.debug("[{}] Executing aggregated search", context.jobId);
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest(buildBaseSearchSource()));
T searchRequest = buildSearchRequest(buildBaseSearchSource());
assert searchRequest.request().allowPartialSearchResults() == false;
SearchResponse searchResponse = executeSearchRequest(searchRequest);
LOGGER.debug("[{}] Search response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
return validateAggs(searchResponse.getAggregations());
}
@@ -166,10 +167,6 @@ abstract class AbstractAggregationDataExtractor<T extends ActionRequestBuilder<S
return new ByteArrayInputStream(outputStream.toByteArray());
}
protected long getHistogramInterval() {
return ExtractorUtils.getHistogramIntervalMillis(context.aggs);
}
public AggregationDataExtractorContext getContext() {
return context;
}

AggregationDataExtractor.java

@@ -29,6 +29,7 @@ class AggregationDataExtractor extends AbstractAggregationDataExtractor<SearchRe
return new SearchRequestBuilder(client, SearchAction.INSTANCE)
.setSource(searchSourceBuilder)
.setIndicesOptions(context.indicesOptions)
.setAllowPartialSearchResults(false)
.setIndices(context.indices);
}
}

AggregationToJsonProcessor.java

@@ -63,8 +63,7 @@ class AggregationToJsonProcessor {
* @param includeDocCount whether to include the doc_count
* @param startTime buckets with a timestamp before this time are discarded
*/
AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime)
throws IOException {
AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime) {
this.timeField = Objects.requireNonNull(timeField);
this.fields = Objects.requireNonNull(fields);
this.includeDocCount = includeDocCount;
@@ -279,7 +278,7 @@ class AggregationToJsonProcessor {
* Adds a leaf key-value. It returns {@code true} if the key added or {@code false} when nothing was added.
* Non-finite metric values are not added.
*/
private boolean processLeaf(Aggregation agg) throws IOException {
private boolean processLeaf(Aggregation agg) {
if (agg instanceof NumericMetricsAggregation.SingleValue) {
return processSingleValue((NumericMetricsAggregation.SingleValue) agg);
} else if (agg instanceof Percentiles) {
@@ -291,7 +290,7 @@
}
}
private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) throws IOException {
private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) {
return addMetricIfFinite(singleValue.getName(), singleValue.value());
}
@@ -311,7 +310,7 @@
return false;
}
private boolean processPercentiles(Percentiles percentiles) throws IOException {
private boolean processPercentiles(Percentiles percentiles) {
Iterator<Percentile> percentileIterator = percentiles.iterator();
boolean aggregationAdded = addMetricIfFinite(percentiles.getName(), percentileIterator.next().getValue());
if (percentileIterator.hasNext()) {

RollupDataExtractor.java

@@ -28,6 +28,7 @@ class RollupDataExtractor extends AbstractAggregationDataExtractor<RollupSearchA
protected RollupSearchAction.RequestBuilder buildSearchRequest(SearchSourceBuilder searchSourceBuilder) {
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
.indicesOptions(context.indicesOptions)
.allowPartialSearchResults(false)
.source(searchSourceBuilder);
return new RollupSearchAction.RequestBuilder(client, searchRequest);

ChunkedDataExtractor.java

@@ -114,7 +114,7 @@ public class ChunkedDataExtractor implements DataExtractor {
return getNextStream();
}
private void setUpChunkedSearch() throws IOException {
private void setUpChunkedSearch() {
DataSummary dataSummary = dataSummaryFactory.buildDataSummary();
if (dataSummary.hasData()) {
currentStart = context.timeAligner.alignToFloor(dataSummary.earliestTime());
@@ -196,21 +196,18 @@ public class ChunkedDataExtractor implements DataExtractor {
* So, if we need to gather an appropriate chunked time for aggregations, we can utilize the AggregatedDataSummary
*
* @return DataSummary object
* @throws IOException when timefield range search fails
*/
private DataSummary buildDataSummary() throws IOException {
private DataSummary buildDataSummary() {
return context.hasAggregations ? newAggregatedDataSummary() : newScrolledDataSummary();
}
private DataSummary newScrolledDataSummary() throws IOException {
private DataSummary newScrolledDataSummary() {
SearchRequestBuilder searchRequestBuilder = rangeSearchRequest();
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
LOGGER.debug("[{}] Scrolling Data summary response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
Aggregations aggregations = searchResponse.getAggregations();
long earliestTime = 0;
long latestTime = 0;
@@ -224,7 +221,7 @@ public class ChunkedDataExtractor implements DataExtractor {
return new ScrolledDataSummary(earliestTime, latestTime, totalHits);
}
private DataSummary newAggregatedDataSummary() throws IOException {
private DataSummary newAggregatedDataSummary() {
// TODO: once RollupSearchAction is changed from indices:admin* to indices:data/read/* this branch is not needed
ActionRequestBuilder<SearchRequest, SearchResponse> searchRequestBuilder =
dataExtractorFactory instanceof RollupDataExtractorFactory ? rollupRangeSearchRequest() : rangeSearchRequest();
@@ -232,8 +229,6 @@ public class ChunkedDataExtractor implements DataExtractor {
LOGGER.debug("[{}] Aggregating Data summary response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
Aggregations aggregations = searchResponse.getAggregations();
Min min = aggregations.get(EARLIEST_TIME);
Max max = aggregations.get(LATEST_TIME);
@@ -253,12 +248,14 @@
.setIndices(context.indices)
.setIndicesOptions(context.indicesOptions)
.setSource(rangeSearchBuilder())
.setAllowPartialSearchResults(false)
.setTrackTotalHits(true);
}
private RollupSearchAction.RequestBuilder rollupRangeSearchRequest() {
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
.indicesOptions(context.indicesOptions)
.allowPartialSearchResults(false)
.source(rangeSearchBuilder());
return new RollupSearchAction.RequestBuilder(client, searchRequest);
}
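
The data summary searches above only need the earliest and latest timestamps, which the extractor reads back via the Min/Max aggregations. A sketch of such a range source (RangeSummarySketch and the aggregation names are illustrative; the real builder lives in rangeSearchBuilder()):

import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder;

class RangeSummarySketch {
    // Illustrative only: size(0) skips hits entirely; the min/max aggregations
    // on the time field give the bounds used to align chunk boundaries.
    static SearchSourceBuilder rangeSource(QueryBuilder query, String timeField) {
        return new SearchSourceBuilder()
            .size(0)
            .query(query)
            .aggregation(AggregationBuilders.min("earliest_time").field(timeField))
            .aggregation(AggregationBuilders.max("latest_time").field(timeField));
    }
}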

ScrollDataExtractor.java

@@ -102,9 +102,13 @@ class ScrollDataExtractor implements DataExtractor {
return scrollId == null ?
Optional.ofNullable(initScroll(context.start)) : Optional.ofNullable(continueScroll());
} catch (Exception e) {
// In case of error make sure we clear the scroll context
clearScroll();
throw e;
scrollId = null;
if (searchHasShardFailure) {
throw e;
}
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
markScrollAsErrored();
return Optional.ofNullable(initScroll(lastTimestamp == null ? context.start : lastTimestamp));
}
}
@@ -127,6 +131,7 @@
.setIndices(context.indices)
.setIndicesOptions(context.indicesOptions)
.setSize(context.scrollSize)
.setAllowPartialSearchResults(false)
.setQuery(ExtractorUtils.wrapInTimeRangeQuery(
context.query, context.extractedFields.timeField(), start, context.end));
@@ -147,14 +152,6 @@
private InputStream processSearchResponse(SearchResponse searchResponse) throws IOException {
scrollId = searchResponse.getScrollId();
if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
markScrollAsErrored();
return initScroll(lastTimestamp == null ? context.start : lastTimestamp);
}
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
if (searchResponse.getHits().getHits().length == 0) {
hasNext = false;
clearScroll();
@@ -190,24 +187,23 @@
try {
searchResponse = executeSearchScrollRequest(scrollId);
} catch (SearchPhaseExecutionException searchExecutionException) {
if (searchHasShardFailure == false) {
LOGGER.debug("[{}] Reinitializing scroll due to SearchPhaseExecutionException", context.jobId);
markScrollAsErrored();
searchResponse =
executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
} else {
if (searchHasShardFailure) {
throw searchExecutionException;
}
LOGGER.debug("[{}] search failed due to SearchPhaseExecutionException. Will attempt again with new scroll",
context.jobId);
markScrollAsErrored();
searchResponse = executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
}
LOGGER.debug("[{}] Search response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
return processSearchResponse(searchResponse);
}
private void markScrollAsErrored() {
void markScrollAsErrored() {
// This could be a transient error with the scroll Id.
// Reinitialise the scroll and try again but only once.
clearScroll();
scrollId = null;
if (lastTimestamp != null) {
lastTimestamp++;
}
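
Net effect of the ScrollDataExtractor changes above: shard failures no longer have to be detected by inspecting the response; with partial results disallowed they arrive as exceptions, and the extractor retries exactly once from the last seen timestamp. A generic sketch of that retry-once shape (RetryOnceSketch is a hypothetical helper, not code from this PR):

import java.util.function.Supplier;

final class RetryOnceSketch {
    // Run a search; on the first failure reset scroll state via onError and
    // retry once. A failure after a retry propagates to the caller, matching
    // the "a second failure is not tolerated" behaviour in the tests below.
    static <T> T searchWithOneRetry(Supplier<T> search, Runnable onError, boolean alreadyFailed) {
        try {
            return search.get();
        } catch (RuntimeException e) {
            if (alreadyFailed) {
                throw e;         // second failure: propagate
            }
            onError.run();       // e.g. markScrollAsErrored(): drop the scroll id, bump lastTimestamp
            return search.get(); // retry; if this also throws, it propagates
        }
    }
}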

AggregationDataExtractorTests.java

@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.datafeed.extractor.aggregation;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
@@ -64,6 +65,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
private class TestDataExtractor extends AggregationDataExtractor {
private SearchResponse nextResponse;
private SearchPhaseExecutionException ex;
TestDataExtractor(long start, long end) {
super(testClient, createContext(start, end), timingStatsReporter);
@@ -72,12 +74,19 @@ public class AggregationDataExtractorTests extends ESTestCase {
@Override
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
capturedSearchRequests.add(searchRequestBuilder);
if (ex != null) {
throw ex;
}
return nextResponse;
}
void setNextResponse(SearchResponse searchResponse) {
nextResponse = searchResponse;
}
void setNextResponseToError(SearchPhaseExecutionException ex) {
this.ex = ex;
}
}
@Before
@@ -246,29 +255,12 @@ public class AggregationDataExtractorTests extends ESTestCase {
assertThat(capturedSearchRequests.size(), equalTo(1));
}
public void testExtractionGivenSearchResponseHasError() throws IOException {
public void testExtractionGivenSearchResponseHasError() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createErrorResponse());
extractor.setNextResponseToError(new SearchPhaseExecutionException("phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testExtractionGivenSearchResponseHasShardFailures() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithShardFailures());
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithUnavailableShards(2));
assertThat(extractor.hasNext(), is(true));
IOException e = expectThrows(IOException.class, extractor::next);
assertThat(e.getMessage(), equalTo("[" + jobId + "] Search request encountered [2] unavailable shards"));
expectThrows(SearchPhaseExecutionException.class, extractor::next);
}
private AggregationDataExtractorContext createContext(long start, long end) {
@@ -295,29 +287,6 @@ public class AggregationDataExtractorTests extends ESTestCase {
return searchResponse;
}
private SearchResponse createErrorResponse() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR);
return searchResponse;
}
private SearchResponse createResponseWithShardFailures() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.OK);
when(searchResponse.getShardFailures()).thenReturn(
new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))});
return searchResponse;
}
private SearchResponse createResponseWithUnavailableShards(int unavailableShards) {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.OK);
when(searchResponse.getSuccessfulShards()).thenReturn(3);
when(searchResponse.getTotalShards()).thenReturn(3 + unavailableShards);
when(searchResponse.getTook()).thenReturn(TimeValue.timeValueMillis(randomNonNegativeLong()));
return searchResponse;
}
private static String asString(InputStream inputStream) throws IOException {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
return reader.lines().collect(Collectors.joining("\n"));

ChunkedDataExtractorTests.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.chunked;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
@@ -62,6 +63,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
private class TestDataExtractor extends ChunkedDataExtractor {
private SearchResponse nextResponse;
private SearchPhaseExecutionException ex;
TestDataExtractor(long start, long end) {
super(client, dataExtractorFactory, createContext(start, end), timingStatsReporter);
@@ -74,12 +76,19 @@ public class ChunkedDataExtractorTests extends ESTestCase {
@Override
protected SearchResponse executeSearchRequest(ActionRequestBuilder<SearchRequest, SearchResponse> searchRequestBuilder) {
capturedSearchRequests.add(searchRequestBuilder.request());
if (ex != null) {
throw ex;
}
return nextResponse;
}
void setNextResponse(SearchResponse searchResponse) {
nextResponse = searchResponse;
}
void setNextResponseToError(SearchPhaseExecutionException ex) {
this.ex = ex;
}
}
@Before
@@ -485,22 +494,13 @@
Mockito.verifyNoMoreInteractions(dataExtractorFactory);
}
public void testDataSummaryRequestIsNotOk() {
public void testDataSummaryRequestIsFailed() {
chunkSpan = TimeValue.timeValueSeconds(2);
TestDataExtractor extractor = new TestDataExtractor(1000L, 2300L);
extractor.setNextResponse(createErrorResponse());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testDataSummaryRequestHasShardFailures() {
chunkSpan = TimeValue.timeValueSeconds(2);
TestDataExtractor extractor = new TestDataExtractor(1000L, 2300L);
extractor.setNextResponse(createResponseWithShardFailures());
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
expectThrows(SearchPhaseExecutionException.class, extractor::next);
}
private SearchResponse createSearchResponse(long totalHits, long earliestTime, long latestTime) {
@@ -545,20 +545,6 @@ public class ChunkedDataExtractorTests extends ESTestCase {
return searchResponse;
}
private SearchResponse createErrorResponse() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR);
return searchResponse;
}
private SearchResponse createResponseWithShardFailures() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.OK);
when(searchResponse.getShardFailures()).thenReturn(
new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))});
return searchResponse;
}
private ChunkedDataExtractorContext createContext(long start, long end) {
return createContext(start, end, false, null);
}

ScrollDataExtractorTests.java

@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.search.ClearScrollAction;
import org.elasticsearch.action.search.ClearScrollRequest;
@@ -16,6 +17,7 @@ import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
@@ -81,7 +83,8 @@ public class ScrollDataExtractorTests extends ESTestCase {
private class TestDataExtractor extends ScrollDataExtractor {
private Queue<SearchResponse> responses = new LinkedList<>();
private Queue<Tuple<SearchResponse, ElasticsearchException>> responses = new LinkedList<>();
private int numScrollReset;
TestDataExtractor(long start, long end) {
this(createContext(start, end));
@@ -100,22 +103,39 @@
@Override
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
capturedSearchRequests.add(searchRequestBuilder);
return responses.remove();
Tuple<SearchResponse, ElasticsearchException> responseOrException = responses.remove();
if (responseOrException.v2() != null) {
throw responseOrException.v2();
}
return responseOrException.v1();
}
@Override
protected SearchResponse executeSearchScrollRequest(String scrollId) {
capturedContinueScrollIds.add(scrollId);
SearchResponse searchResponse = responses.remove();
if (searchResponse == null) {
throw new SearchPhaseExecutionException("foo", "bar", new ShardSearchFailure[] {});
} else {
return searchResponse;
Tuple<SearchResponse, ElasticsearchException> responseOrException = responses.remove();
if (responseOrException.v2() != null) {
throw responseOrException.v2();
}
return responseOrException.v1();
}
@Override
void markScrollAsErrored() {
++numScrollReset;
super.markScrollAsErrored();
}
int getNumScrollReset() {
return numScrollReset;
}
void setNextResponse(SearchResponse searchResponse) {
responses.add(searchResponse);
responses.add(Tuple.tuple(searchResponse, null));
}
void setNextResponseToError(ElasticsearchException ex) {
responses.add(Tuple.tuple(null, ex));
}
public long getInitScrollStartTime() {
@@ -280,12 +300,13 @@
assertThat(capturedClearScrollIds.get(0), equalTo(response2.getScrollId()));
}
public void testExtractionGivenInitSearchResponseHasError() throws IOException {
public void testExtractionGivenInitSearchResponseHasError() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createErrorResponse());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
expectThrows(SearchPhaseExecutionException.class, extractor::next);
}
public void testExtractionGivenContinueScrollResponseHasError() throws IOException {
@@ -302,36 +323,21 @@
Optional<InputStream> stream = extractor.next();
assertThat(stream.isPresent(), is(true));
extractor.setNextResponse(createErrorResponse());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
expectThrows(SearchPhaseExecutionException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
assertThat(capturedClearScrollIds.size(), equalTo(1));
assertThat(extractor.getNumScrollReset(), equalTo(1));
}
public void testExtractionGivenInitSearchResponseHasShardFailures() throws IOException {
public void testExtractionGivenInitSearchResponseEncounteredFailure() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
// We should clear the scroll context twice: once for the first search when we retry
// and once after the retry where we'll have an exception
assertThat(capturedClearScrollIds.size(), equalTo(2));
}
public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() throws IOException {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithUnavailableShards(1));
extractor.setNextResponse(createResponseWithUnavailableShards(1));
assertThat(extractor.hasNext(), is(true));
IOException e = expectThrows(IOException.class, extractor::next);
assertThat(e.getMessage(), equalTo("[" + jobId + "] Search request encountered [1] unavailable shards"));
expectThrows(SearchPhaseExecutionException.class, extractor::next);
}
public void testResetScrollAfterShardFailure() throws IOException {
@@ -343,9 +349,9 @@
Arrays.asList("b1", "b2")
);
extractor.setNextResponse(goodResponse);
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponse(goodResponse);
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
// first response is good
assertThat(extractor.hasNext(), is(true));
@@ -357,13 +363,12 @@
assertThat(output.isPresent(), is(true));
// A second failure is not tolerated
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
expectThrows(SearchPhaseExecutionException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
assertThat(capturedClearScrollIds.size(), equalTo(2));
assertThat(extractor.getNumScrollReset(), equalTo(1));
}
public void testResetScollUsesLastResultTimestamp() throws IOException {
public void testResetScrollUsesLastResultTimestamp() throws IOException {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
SearchResponse goodResponse = createSearchResponse(
@@ -373,14 +378,14 @@
);
extractor.setNextResponse(goodResponse);
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponseToError(new ElasticsearchException("something not search phase exception"));
extractor.setNextResponseToError(new ElasticsearchException("something not search phase exception"));
Optional<InputStream> output = extractor.next();
assertThat(output.isPresent(), is(true));
assertEquals(1000L, extractor.getInitScrollStartTime());
expectThrows(IOException.class, () -> extractor.next());
expectThrows(ElasticsearchException.class, extractor::next);
// the new start time after error is the last record timestamp +1
assertEquals(1201L, extractor.getInitScrollStartTime());
}
@@ -400,9 +405,9 @@
);
extractor.setNextResponse(firstResponse);
extractor.setNextResponse(null); // this will throw a SearchPhaseExecutionException
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponse(secondResponse);
extractor.setNextResponse(null); // this will throw a SearchPhaseExecutionException
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
// first response is good
@@ -418,22 +423,18 @@
assertThat(extractor.hasNext(), is(true));
expectThrows(SearchPhaseExecutionException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
assertThat(capturedClearScrollIds.size(), equalTo(2));
assertThat(extractor.getNumScrollReset(), equalTo(1));
}
public void testSearchPhaseExecutionExceptionOnInitScroll() throws IOException {
public void testSearchPhaseExecutionExceptionOnInitScroll() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
extractor.setNextResponseToError(new SearchPhaseExecutionException("search phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
expectThrows(IOException.class, extractor::next);
expectThrows(SearchPhaseExecutionException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
// We should clear the scroll context twice: once for the first search when we retry
// and once after the retry where we'll have an exception
assertThat(capturedClearScrollIds.size(), equalTo(2));
assertThat(extractor.getNumScrollReset(), equalTo(1));
}
public void testDomainSplitScriptField() throws IOException {
@@ -519,32 +520,6 @@
return searchResponse;
}
private SearchResponse createErrorResponse() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR);
when(searchResponse.getScrollId()).thenReturn(randomAlphaOfLength(1000));
return searchResponse;
}
private SearchResponse createResponseWithShardFailures() {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.OK);
when(searchResponse.getShardFailures()).thenReturn(
new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))});
when(searchResponse.getFailedShards()).thenReturn(1);
when(searchResponse.getScrollId()).thenReturn(randomAlphaOfLength(1000));
return searchResponse;
}
private SearchResponse createResponseWithUnavailableShards(int unavailableShards) {
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.status()).thenReturn(RestStatus.OK);
when(searchResponse.getSuccessfulShards()).thenReturn(2);
when(searchResponse.getTotalShards()).thenReturn(2 + unavailableShards);
when(searchResponse.getFailedShards()).thenReturn(unavailableShards);
return searchResponse;
}
private List<String> getCapturedClearScrollIds() {
return capturedClearScrollRequests.getAllValues().stream().map(r -> r.getScrollIds().get(0)).collect(Collectors.toList());
}