[7.10][ML] Protect against stack overflow while loading DFA data (#64947) (#64956)

If we encounter an exception during extracting data in a data
frame analytics job, we retry once. However, we were not catching
exceptions thrown from processing the search response. This may
result in an infinite loop that causes a stack overflow.

This commit fixes this problem.

Backport of #64947
This commit is contained in:
Dimitris Athanasiou 2020-11-12 11:08:40 +02:00 committed by GitHub
parent e4b77bcd38
commit b5efaf6e3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 8 deletions

View File

@ -65,7 +65,7 @@ public class DataFrameDataExtractor {
private long lastSortKey = -1; private long lastSortKey = -1;
private boolean isCancelled; private boolean isCancelled;
private boolean hasNext; private boolean hasNext;
private boolean searchHasShardFailure; private boolean hasPreviousSearchFailed;
private final CachedSupplier<TrainTestSplitter> trainTestSplitter; private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
// These are fields that are sent directly to the analytics process // These are fields that are sent directly to the analytics process
// They are not passed through a feature_processor // They are not passed through a feature_processor
@ -82,7 +82,7 @@ public class DataFrameDataExtractor {
this.extractedFieldsByName = new LinkedHashMap<>(); this.extractedFieldsByName = new LinkedHashMap<>();
context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f)); context.extractedFields.getAllFields().forEach(f -> this.extractedFieldsByName.put(f.getName(), f));
hasNext = true; hasNext = true;
searchHasShardFailure = false; hasPreviousSearchFailed = false;
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create); this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
} }
@ -129,12 +129,14 @@ public class DataFrameDataExtractor {
SearchResponse searchResponse = request.get(); SearchResponse searchResponse = request.get();
LOGGER.debug("[{}] Search response was obtained", context.jobId); LOGGER.debug("[{}] Search response was obtained", context.jobId);
// Request was successful so we can restore the flag to retry if a future failure occurs List<Row> rows = processSearchResponse(searchResponse);
searchHasShardFailure = false;
return processSearchResponse(searchResponse); // Request was successfully executed and processed so we can restore the flag to retry if a future failure occurs
hasPreviousSearchFailed = false;
return rows;
} catch (Exception e) { } catch (Exception e) {
if (searchHasShardFailure) { if (hasPreviousSearchFailed) {
throw e; throw e;
} }
LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e); LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e);
@ -286,7 +288,7 @@ public class DataFrameDataExtractor {
private void markScrollAsErrored() { private void markScrollAsErrored() {
// This could be a transient error with the scroll Id. // This could be a transient error with the scroll Id.
// Reinitialise the scroll and try again but only once. // Reinitialise the scroll and try again but only once.
searchHasShardFailure = true; hasPreviousSearchFailed = true;
} }
public List<String> getFieldNames() { public List<String> getFieldNames() {

View File

@ -55,6 +55,8 @@ import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
@ -539,6 +541,26 @@ public class DataFrameDataExtractorTests extends ESTestCase {
assertThat(rows.get().get(2).shouldSkip(), is(false)); assertThat(rows.get().get(2).shouldSkip(), is(false));
} }
public void testExtractionWithProcessedFieldThrows() {
ProcessedField processedField = mock(ProcessedField.class);
doThrow(new RuntimeException("process field error")).when(processedField).value(any(), any());
extractedFields = new ExtractedFields(Arrays.asList(
new DocValueField("field_1", Collections.singleton("keyword")),
new DocValueField("field_2", Collections.singleton("keyword"))),
Collections.singletonList(processedField),
Collections.emptyMap());
TestExtractor dataExtractor = createExtractor(true, true);
SearchResponse response = createSearchResponse(Arrays.asList(1_1, null, 1_3), Arrays.asList(2_1, 2_2, 2_3));
dataExtractor.setAlwaysResponse(response);
assertThat(dataExtractor.hasNext(), is(true));
expectThrows(RuntimeException.class, () -> dataExtractor.next());
}
private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize, DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory); headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
@ -594,19 +616,27 @@ public class DataFrameDataExtractorTests extends ESTestCase {
private Queue<SearchResponse> responses = new LinkedList<>(); private Queue<SearchResponse> responses = new LinkedList<>();
private List<SearchRequestBuilder> capturedSearchRequests = new ArrayList<>(); private List<SearchRequestBuilder> capturedSearchRequests = new ArrayList<>();
private SearchResponse alwaysResponse;
TestExtractor(Client client, DataFrameDataExtractorContext context) { TestExtractor(Client client, DataFrameDataExtractorContext context) {
super(client, context); super(client, context);
} }
void setNextResponse(SearchResponse searchResponse) { void setNextResponse(SearchResponse searchResponse) {
if (alwaysResponse != null) {
throw new IllegalStateException("Should not set next response when an always response has been set");
}
responses.add(searchResponse); responses.add(searchResponse);
} }
void setAlwaysResponse(SearchResponse searchResponse) {
alwaysResponse = searchResponse;
}
@Override @Override
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) { protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
capturedSearchRequests.add(searchRequestBuilder); capturedSearchRequests.add(searchRequestBuilder);
SearchResponse searchResponse = responses.remove(); SearchResponse searchResponse = alwaysResponse == null ? responses.remove() : alwaysResponse;
if (searchResponse.getShardFailures() != null) { if (searchResponse.getShardFailures() != null) {
throw new RuntimeException(searchResponse.getShardFailures()[0].getCause()); throw new RuntimeException(searchResponse.getShardFailures()[0].getCause());
} }