diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 6872585929b..fd24eb5d268 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -67,6 +67,7 @@ public class DataFrameDataExtractor { private final Client client; private final DataFrameDataExtractorContext context; private String scrollId; + private String lastSortKey; private boolean isCancelled; private boolean hasNext; private boolean searchHasShardFailure; @@ -122,7 +123,9 @@ public class DataFrameDataExtractor { } Optional> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll()); - if (!hits.isPresent()) { + if (hits.isPresent() && hits.get().isEmpty() == false) { + lastSortKey = hits.get().get(hits.get().size() - 1).getSortKey(); + } else { hasNext = false; } return hits; @@ -135,6 +138,7 @@ public class DataFrameDataExtractor { private List tryRequestWithSearchResponse(Supplier request) throws IOException { try { + // We've set allow_partial_search_results to false which means if something // goes wrong the request will throw. SearchResponse searchResponse = request.get(); @@ -165,8 +169,19 @@ public class DataFrameDataExtractor { .setAllowPartialSearchResults(false) .addSort(DestinationIndex.ID_COPY, SortOrder.ASC) .setIndices(context.indices) - .setSize(context.scrollSize) - .setQuery(context.query); + .setSize(context.scrollSize); + + if (lastSortKey == null) { + searchRequestBuilder.setQuery(context.query); + } else { + LOGGER.debug(() -> new ParameterizedMessage("[{}] Searching docs with [{}] greater than [{}]", + context.jobId, DestinationIndex.ID_COPY, lastSortKey)); + QueryBuilder queryPlusLastSortKey = QueryBuilders.boolQuery() + .filter(context.query) + .filter(QueryBuilders.rangeQuery(DestinationIndex.ID_COPY).gt(lastSortKey)); + searchRequestBuilder.setQuery(queryPlusLastSortKey); + } + setFetchSource(searchRequestBuilder); for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) { @@ -426,5 +441,9 @@ public class DataFrameDataExtractor { public int getChecksum() { return Arrays.hashCode(values); } + + public String getSortKey() { + return (String) hit.getSortValues()[0]; + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index 71b636d2cfe..4805d1d286f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -58,6 +58,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.same; @@ -77,6 +78,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { private TrainTestSplitterFactory trainTestSplitterFactory; private ArgumentCaptor capturedClearScrollRequests; private ActionFuture clearScrollFuture; + private int searchHitCounter; @Before @SuppressWarnings("unchecked") @@ -196,6 +198,13 @@ public class DataFrameDataExtractorTests extends ESTestCase { List capturedClearScrollRequests = getCapturedClearScrollIds(); assertThat(capturedClearScrollRequests.size(), equalTo(1)); assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); + + // Notice we've done two searches here + assertThat(dataExtractor.capturedSearchRequests, hasSize(2)); + + // Assert the second search did not include a range query as the failure happened on the very first search + String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); } public void testErrorOnSearchTwiceLeadsToFailure() { @@ -215,14 +224,14 @@ public class DataFrameDataExtractorTests extends ESTestCase { TestExtractor dataExtractor = createExtractor(true, false); // Search will succeed - SearchResponse response1 = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, 1_2), Arrays.asList(2_1, 2_2)); dataExtractor.setNextResponse(response1); // But the first continue scroll fails dataExtractor.setNextResponse(createResponseWithShardFailures()); // The next one succeeds and we shall recover - SearchResponse response2 = createSearchResponse(Arrays.asList(1_2), Arrays.asList(2_2)); + SearchResponse response2 = createSearchResponse(Arrays.asList(1_3), Arrays.asList(2_3)); dataExtractor.setNextResponse(response2); // Last one @@ -234,15 +243,16 @@ public class DataFrameDataExtractorTests extends ESTestCase { // First batch expected as normally since we'll retry after the error Optional> rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); - assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().size(), equalTo(2)); assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(rows.get().get(1).getValues(), equalTo(new String[] {"12", "22"})); assertThat(dataExtractor.hasNext(), is(true)); // We get second batch as we retried after the error rows = dataExtractor.next(); assertThat(rows.isPresent(), is(true)); assertThat(rows.get().size(), equalTo(1)); - assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"12", "22"})); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"13", "23"})); assertThat(dataExtractor.hasNext(), is(true)); // Next batch should return empty @@ -254,6 +264,12 @@ public class DataFrameDataExtractorTests extends ESTestCase { assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2)); assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2)); + // Assert the second search continued from the latest successfully processed doc + String searchRequest = dataExtractor.capturedSearchRequests.get(1).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"query\":{\"bool\":{")); + assertThat(searchRequest, containsString("{\"match_all\":{\"boost\":1.0}")); + assertThat(searchRequest, containsString("{\"range\":{\"ml__id_copy\":{\"from\":\"1\",\"to\":null,\"include_lower\":false")); + // Check we cleared the scroll with the latest scroll id List capturedClearScrollRequests = getCapturedClearScrollIds(); assertThat(capturedClearScrollRequests.size(), equalTo(1)); @@ -582,6 +598,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { addField(searchHitBuilder, "field_1", field1Values.get(i)); addField(searchHitBuilder, "field_2", field2Values.get(i)); searchHitBuilder.setSource("{\"field_1\":" + field1Values.get(i) + ",\"field_2\":" + field2Values.get(i) + "}"); + searchHitBuilder.setStringSortValue(String.valueOf(searchHitCounter++)); hits.add(searchHitBuilder.build()); } SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java index a112836baa9..ff8d13902ca 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/test/SearchHitBuilder.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.test; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; import java.util.Arrays; @@ -41,6 +42,11 @@ public class SearchHitBuilder { return this; } + public SearchHitBuilder setStringSortValue(String sortValue) { + hit.sortValues(new String[] { sortValue }, new DocValueFormat[] { DocValueFormat.RAW }); + return this; + } + public SearchHit build() { return hit; }