diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 2a6dbecf76e..bf0a987a3de 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -694,7 +694,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, this.modelLoadingService.set(modelLoadingService); // Data frame analytics components - AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, + AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager( + settings, + client, threadPool, analyticsProcessFactory, dataFrameAnalyticsAuditor, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java index 3e03be48929..776b83afad8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunner.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ClientHelper; @@ -27,6 +28,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.util.Deque; @@ -40,8 +42,8 @@ public class InferenceRunner { private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class); private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98; - private static final int RESULTS_BATCH_SIZE = 1000; + private final Settings settings; private final Client client; private final ModelLoadingService modelLoadingService; private final ResultsPersisterService resultsPersisterService; @@ -52,9 +54,10 @@ public class InferenceRunner { private final DataCountsTracker dataCountsTracker; private volatile boolean isCancelled; - public InferenceRunner(Client client, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, - TaskId parentTaskId, DataFrameAnalyticsConfig config, ExtractedFields extractedFields, - ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) { + public InferenceRunner(Settings settings, Client client, ModelLoadingService modelLoadingService, + ResultsPersisterService resultsPersisterService, TaskId parentTaskId, DataFrameAnalyticsConfig config, + ExtractedFields extractedFields, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) { + this.settings = Objects.requireNonNull(settings); this.client = Objects.requireNonNull(client); this.modelLoadingService = Objects.requireNonNull(modelLoadingService); this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); @@ -92,36 +95,29 @@ public class InferenceRunner { void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator) { long totalDocCount = 0; long processedDocCount = 0; - BulkRequest bulkRequest = new BulkRequest(); - while (testDocsIterator.hasNext()) { - if (isCancelled) { - break; + try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) { + while (testDocsIterator.hasNext()) { + if (isCancelled) { + break; + } + + Deque batch = testDocsIterator.next(); + + if (totalDocCount == 0) { + totalDocCount = testDocsIterator.getTotalHits(); + } + + for (SearchHit doc : batch) { + dataCountsTracker.incrementTestDocsCount(); + InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc)); + bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField())); + + processedDocCount++; + int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION); + progressTracker.updateInferenceProgress(progressPercent); + } } - - Deque batch = testDocsIterator.next(); - - if (totalDocCount == 0) { - totalDocCount = testDocsIterator.getTotalHits(); - } - - for (SearchHit doc : batch) { - dataCountsTracker.incrementTestDocsCount(); - InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc)); - bulkRequest.add(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField())); - - processedDocCount++; - int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION); - progressTracker.updateInferenceProgress(progressPercent); - } - - if (bulkRequest.numberOfActions() == RESULTS_BATCH_SIZE) { - executeBulkRequest(bulkRequest); - bulkRequest = new BulkRequest(); - } - } - if (bulkRequest.numberOfActions() > 0 && isCancelled == false) { - executeBulkRequest(bulkRequest); } if (isCancelled == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 31baca95cf0..13358461e43 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -18,6 +18,7 @@ import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; @@ -61,6 +62,7 @@ public class AnalyticsProcessManager { private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class); + private final Settings settings; private final Client client; private final ExecutorService executorServiceForJob; private final ExecutorService executorServiceForProcess; @@ -72,7 +74,8 @@ public class AnalyticsProcessManager { private final ResultsPersisterService resultsPersisterService; private final int numAllocatedProcessors; - public AnalyticsProcessManager(Client client, + public AnalyticsProcessManager(Settings settings, + Client client, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, @@ -81,6 +84,7 @@ public class AnalyticsProcessManager { ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) { this( + settings, client, threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), @@ -93,7 +97,8 @@ public class AnalyticsProcessManager { } // Visible for testing - public AnalyticsProcessManager(Client client, + public AnalyticsProcessManager(Settings settings, + Client client, ExecutorService executorServiceForJob, ExecutorService executorServiceForProcess, AnalyticsProcessFactory analyticsProcessFactory, @@ -102,6 +107,7 @@ public class AnalyticsProcessManager { ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) { + this.settings = Objects.requireNonNull(settings); this.client = Objects.requireNonNull(client); this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob); this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess); @@ -330,7 +336,7 @@ public class AnalyticsProcessManager { if (processContext.config.getAnalysis().supportsInference()) { refreshDest(parentTaskClient, processContext.config); - InferenceRunner inferenceRunner = new InferenceRunner(parentTaskClient, modelLoadingService, resultsPersisterService, + InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService, task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(), task.getStatsHolder().getDataCountsTracker()); processContext.setInferenceRunner(inferenceRunner); @@ -489,7 +495,7 @@ public class AnalyticsProcessManager { private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task, DataFrameDataExtractorFactory dataExtractorFactory) { DataFrameRowsJoiner dataFrameRowsJoiner = - new DataFrameRowsJoiner(config.getId(), task.getParentTaskId(), + new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(), dataExtractorFactory.newExtractor(true), resultsPersisterService); return new AnalyticsResultProcessor( config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java index 8ec023ec573..3a2234e0398 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -12,11 +12,13 @@ import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; import java.io.IOException; @@ -36,6 +38,7 @@ class DataFrameRowsJoiner implements AutoCloseable { private static final int RESULTS_BATCH_SIZE = 1000; private final String analyticsId; + private final Settings settings; private final TaskId parentTaskId; private final DataFrameDataExtractor dataExtractor; private final ResultsPersisterService resultsPersisterService; @@ -44,9 +47,10 @@ class DataFrameRowsJoiner implements AutoCloseable { private volatile String failure; private volatile boolean isCancelled; - DataFrameRowsJoiner(String analyticsId, TaskId parentTaskId, DataFrameDataExtractor dataExtractor, + DataFrameRowsJoiner(String analyticsId, Settings settings, TaskId parentTaskId, DataFrameDataExtractor dataExtractor, ResultsPersisterService resultsPersisterService) { this.analyticsId = Objects.requireNonNull(analyticsId); + this.settings = Objects.requireNonNull(settings); this.parentTaskId = Objects.requireNonNull(parentTaskId); this.dataExtractor = Objects.requireNonNull(dataExtractor); this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService); @@ -86,25 +90,28 @@ class DataFrameRowsJoiner implements AutoCloseable { } private void joinCurrentResults() { - BulkRequest bulkRequest = new BulkRequest(); - while (currentResults.isEmpty() == false) { - RowResults result = currentResults.pop(); - DataFrameDataExtractor.Row row = dataFrameRowsIterator.next(); - checkChecksumsMatch(row, result); - bulkRequest.add(createIndexRequest(result, row.getHit())); - } - if (bulkRequest.numberOfActions() > 0) { - bulkRequest.setParentTask(parentTaskId); - resultsPersisterService.bulkIndexWithHeadersWithRetry( - dataExtractor.getHeaders(), - bulkRequest, - analyticsId, - () -> isCancelled == false, - errorMsg -> {}); + try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) { + while (currentResults.isEmpty() == false) { + RowResults result = currentResults.pop(); + DataFrameDataExtractor.Row row = dataFrameRowsIterator.next(); + checkChecksumsMatch(row, result); + bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(result, row.getHit())); + } } + currentResults = new LinkedList<>(); } + private void executeBulkRequest(BulkRequest bulkRequest) { + bulkRequest.setParentTask(parentTaskId); + resultsPersisterService.bulkIndexWithHeadersWithRetry( + dataExtractor.getHeaders(), + bulkRequest, + analyticsId, + () -> isCancelled == false, + errorMsg -> {}); + } + private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) { if (row.getChecksum() != result.getChecksum()) { String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexer.java new file mode 100644 index 00000000000..04518241d70 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexer.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.utils.persistence; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexingPressure; + +import java.util.Objects; +import java.util.function.Consumer; + +/** + * A helper class that gathers index requests in bulk requests + * that do exceed a 1000 operations or half the available memory + * limit for indexing. + */ +public class LimitAwareBulkIndexer implements AutoCloseable { + + private static final Logger LOGGER = LogManager.getLogger(LimitAwareBulkIndexer.class); + + private static final int BATCH_SIZE = 1000; + + private final long bytesLimit; + private final Consumer executor; + private BulkRequest currentBulkRequest = new BulkRequest(); + private long currentRamBytes; + + public LimitAwareBulkIndexer(Settings settings, Consumer executor) { + this((long) Math.ceil(0.5 * IndexingPressure.MAX_INDEXING_BYTES.get(settings).getBytes()), executor); + } + + LimitAwareBulkIndexer(long bytesLimit, Consumer executor) { + this.bytesLimit = bytesLimit; + this.executor = Objects.requireNonNull(executor); + } + + public void addAndExecuteIfNeeded(IndexRequest indexRequest) { + if (currentRamBytes + indexRequest.ramBytesUsed() > bytesLimit || currentBulkRequest.numberOfActions() == BATCH_SIZE) { + execute(); + } + currentBulkRequest.add(indexRequest); + currentRamBytes += indexRequest.ramBytesUsed(); + } + + private void execute() { + if (currentBulkRequest.numberOfActions() > 0) { + LOGGER.debug("Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]", + currentRamBytes, bytesLimit, currentBulkRequest.numberOfActions()); + executor.accept(currentBulkRequest); + currentBulkRequest = new BulkRequest(); + currentRamBytes = 0; + } + } + + @Override + public void close() { + execute(); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java index c54a43a5bba..a9f0ec23bfc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/inference/InferenceRunnerTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.search.SearchHit; @@ -164,7 +165,7 @@ public class InferenceRunnerTests extends ESTestCase { } private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) { - return new InferenceRunner(client, modelLoadingService, resultsPersisterService, parentTaskId, config, extractedFields, - progressTracker, new DataCountsTracker()); + return new InferenceRunner(Settings.EMPTY, client, modelLoadingService, resultsPersisterService, parentTaskId, config, + extractedFields, progressTracker, new DataCountsTracker()); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index dd155e30b40..9fbf881d530 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -112,8 +112,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase { resultsPersisterService = mock(ResultsPersisterService.class); modelLoadingService = mock(ModelLoadingService.class); - processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, - trainedModelProvider, modelLoadingService, resultsPersisterService, 1); + processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess, + processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1); } public void testRunJob_TaskIsStopping() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java index dfa6b872eea..bd21de1ff6a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.tasks.TaskId; @@ -31,6 +32,7 @@ import java.util.Optional; import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; @@ -264,7 +266,10 @@ public class DataFrameRowsJoinerTests extends ESTestCase { RowResults result2 = new RowResults(2, resultFields); givenProcessResults(Arrays.asList(result1, result2)); - verifyNoMoreInteractions(resultsPersisterService); + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests, hasSize(1)); + BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); } public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException { @@ -284,7 +289,8 @@ public class DataFrameRowsJoinerTests extends ESTestCase { } private void givenProcessResults(List results) { - try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, new TaskId(""), dataExtractor, resultsPersisterService)) { + try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, Settings.EMPTY, new TaskId(""), dataExtractor, + resultsPersisterService)) { results.forEach(joiner::processRowResults); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexerTests.java new file mode 100644 index 00000000000..2730a99c26b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/LimitAwareBulkIndexerTests.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.utils.persistence; + +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.test.ESTestCase; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class LimitAwareBulkIndexerTests extends ESTestCase { + + private List executedBulkRequests = new ArrayList<>(); + + public void testAddAndExecuteIfNeeded_GivenRequestsReachingBytesLimit() { + try (LimitAwareBulkIndexer bulkIndexer = createIndexer(100)) { + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50)); + assertThat(executedBulkRequests, is(empty())); + + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50)); + assertThat(executedBulkRequests, is(empty())); + + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50)); + assertThat(executedBulkRequests, hasSize(1)); + assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(2)); + + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50)); + assertThat(executedBulkRequests, hasSize(1)); + + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50)); + assertThat(executedBulkRequests, hasSize(2)); + assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(2)); + } + + assertThat(executedBulkRequests, hasSize(3)); + assertThat(executedBulkRequests.get(2).numberOfActions(), equalTo(1)); + } + + public void testAddAndExecuteIfNeeded_GivenRequestsReachingBatchSize() { + try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) { + for (int i = 0; i < 1000; i++) { + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1)); + } + assertThat(executedBulkRequests, is(empty())); + + bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1)); + + assertThat(executedBulkRequests, hasSize(1)); + assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(1000)); + } + + assertThat(executedBulkRequests, hasSize(2)); + assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(1)); + } + + public void testNoRequests() { + try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) { + } + + assertThat(executedBulkRequests, is(empty())); + } + + private LimitAwareBulkIndexer createIndexer(long bytesLimit) { + return new LimitAwareBulkIndexer(bytesLimit, executedBulkRequests::add); + } + + private static IndexRequest mockIndexRequest(long ramBytes) { + IndexRequest indexRequest = mock(IndexRequest.class); + when(indexRequest.ramBytesUsed()).thenReturn(ramBytes); + return indexRequest; + } +} +