diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java index aaee3f2528c..2e0fec5da3e 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequestBuilder; @@ -214,13 +215,15 @@ class ScrollDataExtractor implements DataExtractor { } private void resetScroll() { - if (scrollId != null) { - clearScroll(scrollId); - } + clearScroll(scrollId); scrollId = null; } - void clearScroll(String scrollId) { - ClearScrollAction.INSTANCE.newRequestBuilder(client).addScrollId(scrollId).get(); + private void clearScroll(String scrollId) { + if (scrollId != null) { + ClearScrollRequest request = new ClearScrollRequest(); + request.addScrollId(scrollId); + client.execute(ClearScrollAction.INSTANCE, request).actionGet(); + } } } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java index fe900755044..a6de545db92 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractorTests.java @@ -5,6 +5,10 @@ */ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.ClearScrollResponse; import org.elasticsearch.action.search.SearchPhaseExecutionException; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; @@ -21,6 +25,7 @@ import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ESTestCase; import org.junit.Before; +import org.mockito.ArgumentCaptor; import java.io.BufferedReader; import java.io.IOException; @@ -42,6 +47,7 @@ import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -50,7 +56,7 @@ public class ScrollDataExtractorTests extends ESTestCase { private Client client; private List capturedSearchRequests; private List capturedContinueScrollIds; - private List capturedClearScrollIds; + private ArgumentCaptor capturedClearScrollRequests; private String jobId; private ExtractedFields extractedFields; private List types; @@ -59,6 +65,7 @@ public class ScrollDataExtractorTests extends ESTestCase { private List scriptFields; private int scrollSize; private long initScrollStartTime; + private ActionFuture clearScrollFuture; private class TestDataExtractor extends ScrollDataExtractor { @@ -95,11 +102,6 @@ public class ScrollDataExtractorTests extends ESTestCase { } } - @Override - void clearScroll(String scrollId) { - capturedClearScrollIds.add(scrollId); - } - void setNextResponse(SearchResponse searchResponse) { responses.add(searchResponse); } @@ -118,7 +120,6 @@ public class ScrollDataExtractorTests extends ESTestCase { client = mock(Client.class); capturedSearchRequests = new ArrayList<>(); capturedContinueScrollIds = new ArrayList<>(); - capturedClearScrollIds = new ArrayList<>(); jobId = "test-job"; ExtractedField timeField = ExtractedField.newField("time", ExtractedField.ExtractionMethod.DOC_VALUE); extractedFields = new ExtractedFields(timeField, @@ -128,6 +129,10 @@ public class ScrollDataExtractorTests extends ESTestCase { query = QueryBuilders.matchAllQuery(); scriptFields = Collections.emptyList(); scrollSize = 1000; + + clearScrollFuture = mock(ActionFuture.class); + capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class); + when(client.execute(same(ClearScrollAction.INSTANCE), capturedClearScrollRequests.capture())).thenReturn(clearScrollFuture); } public void testSinglePageExtraction() throws IOException { @@ -164,6 +169,7 @@ public class ScrollDataExtractorTests extends ESTestCase { assertThat(capturedContinueScrollIds.size(), equalTo(1)); assertThat(capturedContinueScrollIds.get(0), equalTo(response1.getScrollId())); + List capturedClearScrollIds = getCapturedClearScrollIds(); assertThat(capturedClearScrollIds.size(), equalTo(1)); assertThat(capturedClearScrollIds.get(0), equalTo(response2.getScrollId())); } @@ -215,6 +221,7 @@ public class ScrollDataExtractorTests extends ESTestCase { assertThat(capturedContinueScrollIds.get(0), equalTo(response1.getScrollId())); assertThat(capturedContinueScrollIds.get(1), equalTo(response2.getScrollId())); + List capturedClearScrollIds = getCapturedClearScrollIds(); assertThat(capturedClearScrollIds.size(), equalTo(1)); assertThat(capturedClearScrollIds.get(0), equalTo(response3.getScrollId())); } @@ -252,6 +259,7 @@ public class ScrollDataExtractorTests extends ESTestCase { assertThat(asString(stream.get()), equalTo(expectedStream)); assertThat(extractor.hasNext(), is(false)); + List capturedClearScrollIds = getCapturedClearScrollIds(); assertThat(capturedClearScrollIds.size(), equalTo(1)); assertThat(capturedClearScrollIds.get(0), equalTo(response2.getScrollId())); } @@ -392,6 +400,7 @@ public class ScrollDataExtractorTests extends ESTestCase { expectThrows(IOException.class, () -> extractor.next()); + List capturedClearScrollIds = getCapturedClearScrollIds(); assertThat(capturedClearScrollIds.isEmpty(), is(true)); } @@ -445,6 +454,7 @@ public class ScrollDataExtractorTests extends ESTestCase { assertThat(capturedContinueScrollIds.size(), equalTo(1)); assertThat(capturedContinueScrollIds.get(0), equalTo(response1.getScrollId())); + List capturedClearScrollIds = getCapturedClearScrollIds(); assertThat(capturedClearScrollIds.size(), equalTo(1)); assertThat(capturedClearScrollIds.get(0), equalTo(response2.getScrollId())); } @@ -500,6 +510,10 @@ public class ScrollDataExtractorTests extends ESTestCase { return searchResponse; } + private List getCapturedClearScrollIds() { + return capturedClearScrollRequests.getAllValues().stream().map(r -> r.getScrollIds().get(0)).collect(Collectors.toList()); + } + private static String asString(InputStream inputStream) throws IOException { try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { return reader.lines().collect(Collectors.joining("\n"));