From b5efaf6e3bc01c999a3764a080314284f95a0a4d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 12 Nov 2020 11:08:40 +0200 Subject: [PATCH] [7.10][ML] Protect against stack overflow while loading DFA data (#64947) (#64956) If we encounter an exception during extracting data in a data frame analytics job, we retry once. However, we were not catching exceptions thrown from processing the search response. This may result in an infinite loop that causes a stack overflow. This commit fixes this problem. Backport of #64947 --- .../extractor/DataFrameDataExtractor.java | 16 ++++++---- .../DataFrameDataExtractorTests.java | 32 ++++++++++++++++++- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 015677dd3ca..defd071bfc9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -65,7 +65,7 @@ public class DataFrameDataExtractor { private long lastSortKey = -1; private boolean isCancelled; private boolean hasNext; - private boolean searchHasShardFailure; + private boolean hasPreviousSearchFailed; private final CachedSupplier trainTestSplitter; // These are fields that are sent directly to the analytics process // They are not passed through a feature_processor @@ -82,7 +82,7 @@ public class DataFrameDataExtractor { this.extractedFieldsByName = new LinkedHashMap<>(); context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f)); hasNext = true; - searchHasShardFailure = false; + hasPreviousSearchFailed = false; this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create); } @@ -129,12 +129,14 @@ public class DataFrameDataExtractor { SearchResponse searchResponse = request.get(); LOGGER.debug("[{}] Search response was obtained", context.jobId); - // Request was successful so we can restore the flag to retry if a future failure occurs - searchHasShardFailure = false; + List rows = processSearchResponse(searchResponse); - return processSearchResponse(searchResponse); + // Request was successfully executed and processed so we can restore the flag to retry if a future failure occurs + hasPreviousSearchFailed = false; + + return rows; } catch (Exception e) { - if (searchHasShardFailure) { + if (hasPreviousSearchFailed) { throw e; } LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e); @@ -286,7 +288,7 @@ public class DataFrameDataExtractor { private void markScrollAsErrored() { // This could be a transient error with the scroll Id. // Reinitialise the scroll and try again but only once. - searchHasShardFailure = true; + hasPreviousSearchFailed = true; } public List getFieldNames() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index c27e46f0761..6fcbbd79efd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -55,6 +55,8 @@ import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -539,6 +541,26 @@ public class DataFrameDataExtractorTests extends ESTestCase { assertThat(rows.get().get(2).shouldSkip(), is(false)); } + public void testExtractionWithProcessedFieldThrows() { + ProcessedField processedField = mock(ProcessedField.class); + doThrow(new RuntimeException("process field error")).when(processedField).value(any(), any()); + + extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("keyword")), + new DocValueField("field_2", Collections.singleton("keyword"))), + Collections.singletonList(processedField), + Collections.emptyMap()); + + TestExtractor dataExtractor = createExtractor(true, true); + + SearchResponse response = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3)); + dataExtractor.setAlwaysResponse(response); + + assertThat(dataExtractor.hasNext(), is(true)); + + expectThrows(RuntimeException.class, () -> dataExtractor.next()); + } + private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory); @@ -594,19 +616,27 @@ public class DataFrameDataExtractorTests extends ESTestCase { private Queue responses = new LinkedList<>(); private List capturedSearchRequests = new ArrayList<>(); + private SearchResponse alwaysResponse; TestExtractor(Client client, DataFrameDataExtractorContext context) { super(client, context); } void setNextResponse(SearchResponse searchResponse) { + if (alwaysResponse != null) { + throw new IllegalStateException("Should not set next response when an always response has been set"); + } responses.add(searchResponse); } + void setAlwaysResponse(SearchResponse searchResponse) { + alwaysResponse = searchResponse; + } + @Override protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) { capturedSearchRequests.add(searchRequestBuilder); - SearchResponse searchResponse = responses.remove(); + SearchResponse searchResponse = alwaysResponse == null ? responses.remove() : alwaysResponse; if (searchResponse.getShardFailures() != null) { throw new RuntimeException(searchResponse.getShardFailures()[0].getCause()); }