Data frame analytics jobs that work with very large datasets may produce bulk requests that exceed the memory limit for indexing. This commit adds a helper class that bundles index requests into bulk requests that stay under that limit. The class is then used by both the results joiner and the inference runner, ensuring data frame analytics jobs do not generate bulk requests that are too large. Note that the limit was introduced in #58885.

Backport of #60219
Parent: b78caa5c00
Commit: 16ffcfb9f6
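For orientation, a minimal sketch of the pattern this diff introduces. The caller class, method, and the blocking client.bulk(...).actionGet() executor below are illustrative assumptions, not part of the commit; the production code passes the node settings and a retrying executor (this::executeBulkRequest):

    import org.elasticsearch.action.index.IndexRequest;
    import org.elasticsearch.client.Client;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;

    import java.util.List;

    class BulkIndexingSketch {
        // Hypothetical caller. With Settings.EMPTY the indexer derives its byte
        // budget from the default indexing-pressure memory limit.
        static void indexAll(Client client, List<IndexRequest> requests) {
            try (LimitAwareBulkIndexer bulkIndexer =
                     new LimitAwareBulkIndexer(Settings.EMPTY, bulk -> client.bulk(bulk).actionGet())) {
                for (IndexRequest request : requests) {
                    // Flushes the pending bulk first if adding this request would
                    // cross the byte budget or the 1000-action batch size.
                    bulkIndexer.addAndExecuteIfNeeded(request);
                }
            } // close() flushes whatever remains
        }
    }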
@@ -694,7 +694,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         this.modelLoadingService.set(modelLoadingService);
 
         // Data frame analytics components
-        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client,
+        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(
+            settings,
+            client,
             threadPool,
             analyticsProcessFactory,
             dataFrameAnalyticsAuditor,
@@ -14,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ClientHelper;
@@ -27,6 +28,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.util.Deque;
@@ -40,8 +42,8 @@ public class InferenceRunner {
     private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class);
 
     private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
-    private static final int RESULTS_BATCH_SIZE = 1000;
 
+    private final Settings settings;
     private final Client client;
     private final ModelLoadingService modelLoadingService;
     private final ResultsPersisterService resultsPersisterService;
@@ -52,9 +54,10 @@ public class InferenceRunner {
     private final DataCountsTracker dataCountsTracker;
     private volatile boolean isCancelled;
 
-    public InferenceRunner(Client client, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService,
-                           TaskId parentTaskId, DataFrameAnalyticsConfig config, ExtractedFields extractedFields,
-                           ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
+    public InferenceRunner(Settings settings, Client client, ModelLoadingService modelLoadingService,
+                           ResultsPersisterService resultsPersisterService, TaskId parentTaskId, DataFrameAnalyticsConfig config,
+                           ExtractedFields extractedFields, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
         this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
@@ -92,36 +95,29 @@ public class InferenceRunner {
     void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator) {
         long totalDocCount = 0;
         long processedDocCount = 0;
-        BulkRequest bulkRequest = new BulkRequest();
-
-        while (testDocsIterator.hasNext()) {
-            if (isCancelled) {
-                break;
+        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
+            while (testDocsIterator.hasNext()) {
+                if (isCancelled) {
+                    break;
+                }
+
+                Deque<SearchHit> batch = testDocsIterator.next();
+
+                if (totalDocCount == 0) {
+                    totalDocCount = testDocsIterator.getTotalHits();
+                }
+
+                for (SearchHit doc : batch) {
+                    dataCountsTracker.incrementTestDocsCount();
+                    InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
+                    bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));
+
+                    processedDocCount++;
+                    int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
+                    progressTracker.updateInferenceProgress(progressPercent);
+                }
             }
-
-            Deque<SearchHit> batch = testDocsIterator.next();
-
-            if (totalDocCount == 0) {
-                totalDocCount = testDocsIterator.getTotalHits();
-            }
-
-            for (SearchHit doc : batch) {
-                dataCountsTracker.incrementTestDocsCount();
-                InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
-                bulkRequest.add(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));
-
-                processedDocCount++;
-                int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
-                progressTracker.updateInferenceProgress(progressPercent);
-            }
-
-            if (bulkRequest.numberOfActions() == RESULTS_BATCH_SIZE) {
-                executeBulkRequest(bulkRequest);
-                bulkRequest = new BulkRequest();
-            }
-        }
-        if (bulkRequest.numberOfActions() > 0 && isCancelled == false) {
-            executeBulkRequest(bulkRequest);
         }
 
         if (isCancelled == false) {
@@ -18,6 +18,7 @@ import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
@@ -61,6 +62,7 @@ public class AnalyticsProcessManager {
 
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
 
+    private final Settings settings;
     private final Client client;
     private final ExecutorService executorServiceForJob;
     private final ExecutorService executorServiceForProcess;
@@ -72,7 +74,8 @@ public class AnalyticsProcessManager {
     private final ResultsPersisterService resultsPersisterService;
     private final int numAllocatedProcessors;
 
-    public AnalyticsProcessManager(Client client,
+    public AnalyticsProcessManager(Settings settings,
+                                   Client client,
                                    ThreadPool threadPool,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
@@ -81,6 +84,7 @@ public class AnalyticsProcessManager {
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this(
+            settings,
             client,
             threadPool.generic(),
             threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
@@ -93,7 +97,8 @@ public class AnalyticsProcessManager {
     }
 
     // Visible for testing
-    public AnalyticsProcessManager(Client client,
+    public AnalyticsProcessManager(Settings settings,
+                                   Client client,
                                    ExecutorService executorServiceForJob,
                                    ExecutorService executorServiceForProcess,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
@@ -102,6 +107,7 @@ public class AnalyticsProcessManager {
                                    ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
         this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
         this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
@@ -330,7 +336,7 @@ public class AnalyticsProcessManager {
 
         if (processContext.config.getAnalysis().supportsInference()) {
             refreshDest(parentTaskClient, processContext.config);
-            InferenceRunner inferenceRunner = new InferenceRunner(parentTaskClient, modelLoadingService, resultsPersisterService,
+            InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService,
                 task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(),
                 task.getStatsHolder().getDataCountsTracker());
             processContext.setInferenceRunner(inferenceRunner);
@@ -489,7 +495,7 @@ public class AnalyticsProcessManager {
     private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
                                                            DataFrameDataExtractorFactory dataExtractorFactory) {
         DataFrameRowsJoiner dataFrameRowsJoiner =
-            new DataFrameRowsJoiner(config.getId(), task.getParentTaskId(),
+            new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(),
                 dataExtractorFactory.newExtractor(true), resultsPersisterService);
         return new AnalyticsResultProcessor(
             config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,
@@ -12,11 +12,13 @@ import org.elasticsearch.action.DocWriteRequest;
 import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
+import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.io.IOException;
@@ -36,6 +38,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
     private static final int RESULTS_BATCH_SIZE = 1000;
 
     private final String analyticsId;
+    private final Settings settings;
     private final TaskId parentTaskId;
     private final DataFrameDataExtractor dataExtractor;
     private final ResultsPersisterService resultsPersisterService;
@@ -44,9 +47,10 @@ class DataFrameRowsJoiner implements AutoCloseable {
     private volatile String failure;
     private volatile boolean isCancelled;
 
-    DataFrameRowsJoiner(String analyticsId, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
+    DataFrameRowsJoiner(String analyticsId, Settings settings, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
                         ResultsPersisterService resultsPersisterService) {
         this.analyticsId = Objects.requireNonNull(analyticsId);
+        this.settings = Objects.requireNonNull(settings);
         this.parentTaskId = Objects.requireNonNull(parentTaskId);
         this.dataExtractor = Objects.requireNonNull(dataExtractor);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
@@ -86,25 +90,28 @@ class DataFrameRowsJoiner implements AutoCloseable {
     }
 
     private void joinCurrentResults() {
-        BulkRequest bulkRequest = new BulkRequest();
-        while (currentResults.isEmpty() == false) {
-            RowResults result = currentResults.pop();
-            DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
-            checkChecksumsMatch(row, result);
-            bulkRequest.add(createIndexRequest(result, row.getHit()));
-        }
-        if (bulkRequest.numberOfActions() > 0) {
-            bulkRequest.setParentTask(parentTaskId);
-            resultsPersisterService.bulkIndexWithHeadersWithRetry(
-                dataExtractor.getHeaders(),
-                bulkRequest,
-                analyticsId,
-                () -> isCancelled == false,
-                errorMsg -> {});
+        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
+            while (currentResults.isEmpty() == false) {
+                RowResults result = currentResults.pop();
+                DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
+                checkChecksumsMatch(row, result);
+                bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(result, row.getHit()));
+            }
         }
 
         currentResults = new LinkedList<>();
     }
 
+    private void executeBulkRequest(BulkRequest bulkRequest) {
+        bulkRequest.setParentTask(parentTaskId);
+        resultsPersisterService.bulkIndexWithHeadersWithRetry(
+            dataExtractor.getHeaders(),
+            bulkRequest,
+            analyticsId,
+            () -> isCancelled == false,
+            errorMsg -> {});
+    }
+
     private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) {
         if (row.getChecksum() != result.getChecksum()) {
             String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; ";
@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexingPressure;
+
+import java.util.Objects;
+import java.util.function.Consumer;
+
+/**
+ * A helper class that gathers index requests in bulk requests
+ * that do not exceed 1000 operations or half the available memory
+ * limit for indexing.
+ */
+public class LimitAwareBulkIndexer implements AutoCloseable {
+
+    private static final Logger LOGGER = LogManager.getLogger(LimitAwareBulkIndexer.class);
+
+    private static final int BATCH_SIZE = 1000;
+
+    private final long bytesLimit;
+    private final Consumer<BulkRequest> executor;
+    private BulkRequest currentBulkRequest = new BulkRequest();
+    private long currentRamBytes;
+
+    public LimitAwareBulkIndexer(Settings settings, Consumer<BulkRequest> executor) {
+        this((long) Math.ceil(0.5 * IndexingPressure.MAX_INDEXING_BYTES.get(settings).getBytes()), executor);
+    }
+
+    LimitAwareBulkIndexer(long bytesLimit, Consumer<BulkRequest> executor) {
+        this.bytesLimit = bytesLimit;
+        this.executor = Objects.requireNonNull(executor);
+    }
+
+    public void addAndExecuteIfNeeded(IndexRequest indexRequest) {
+        if (currentRamBytes + indexRequest.ramBytesUsed() > bytesLimit || currentBulkRequest.numberOfActions() == BATCH_SIZE) {
+            execute();
+        }
+        currentBulkRequest.add(indexRequest);
+        currentRamBytes += indexRequest.ramBytesUsed();
+    }
+
+    private void execute() {
+        if (currentBulkRequest.numberOfActions() > 0) {
+            LOGGER.debug("Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]",
+                currentRamBytes, bytesLimit, currentBulkRequest.numberOfActions());
+            executor.accept(currentBulkRequest);
+            currentBulkRequest = new BulkRequest();
+            currentRamBytes = 0;
+        }
+    }
+
+    @Override
+    public void close() {
+        execute();
+    }
+}
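Sizing note: the public constructor halves IndexingPressure.MAX_INDEXING_BYTES, i.e. the indexing_pressure.memory.limit setting introduced in #58885, which defaults to 10% of the heap. On a 4 GB heap that works out to a flush threshold of roughly 0.5 × 0.10 × 4 GB = 200 MB of pending index requests, with the 1000-action batch size acting as a second cap.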
@@ -12,6 +12,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.search.SearchHit;
@@ -164,7 +165,7 @@ public class InferenceRunnerTests extends ESTestCase {
     }
 
     private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) {
-        return new InferenceRunner(client, modelLoadingService, resultsPersisterService, parentTaskId, config, extractedFields,
-            progressTracker, new DataCountsTracker());
+        return new InferenceRunner(Settings.EMPTY, client, modelLoadingService, resultsPersisterService, parentTaskId, config,
+            extractedFields, progressTracker, new DataCountsTracker());
     }
 }
@@ -112,8 +112,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
 
         resultsPersisterService = mock(ResultsPersisterService.class);
         modelLoadingService = mock(ModelLoadingService.class);
-        processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
-            trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
+        processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess,
+            processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
     }
 
     public void testRunJob_TaskIsStopping() {
@@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.text.Text;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
@@ -31,6 +32,7 @@ import java.util.Optional;
 import java.util.stream.IntStream;
 
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
@@ -264,7 +266,10 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
         RowResults result2 = new RowResults(2, resultFields);
         givenProcessResults(Arrays.asList(result1, result2));
 
-        verifyNoMoreInteractions(resultsPersisterService);
+        List<BulkRequest> capturedBulkRequests = bulkRequestCaptor.getAllValues();
+        assertThat(capturedBulkRequests, hasSize(1));
+        BulkRequest capturedBulkRequest = capturedBulkRequests.get(0);
+        assertThat(capturedBulkRequest.numberOfActions(), equalTo(1));
     }
 
     public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
@@ -284,7 +289,8 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
     }
 
     private void givenProcessResults(List<RowResults> results) {
-        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, new TaskId(""), dataExtractor, resultsPersisterService)) {
+        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, Settings.EMPTY, new TaskId(""), dataExtractor,
+                resultsPersisterService)) {
             results.forEach(joiner::processRowResults);
         }
     }
@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class LimitAwareBulkIndexerTests extends ESTestCase {
+
+    private List<BulkRequest> executedBulkRequests = new ArrayList<>();
+
+    public void testAddAndExecuteIfNeeded_GivenRequestsReachingBytesLimit() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(100)) {
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(1));
+            assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(2));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(1));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(2));
+            assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(2));
+        }
+
+        assertThat(executedBulkRequests, hasSize(3));
+        assertThat(executedBulkRequests.get(2).numberOfActions(), equalTo(1));
+    }
+
+    public void testAddAndExecuteIfNeeded_GivenRequestsReachingBatchSize() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) {
+            for (int i = 0; i < 1000; i++) {
+                bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1));
+            }
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1));
+
+            assertThat(executedBulkRequests, hasSize(1));
+            assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(1000));
+        }
+
+        assertThat(executedBulkRequests, hasSize(2));
+        assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(1));
+    }
+
+    public void testNoRequests() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) {
+        }
+
+        assertThat(executedBulkRequests, is(empty()));
+    }
+
+    private LimitAwareBulkIndexer createIndexer(long bytesLimit) {
+        return new LimitAwareBulkIndexer(bytesLimit, executedBulkRequests::add);
+    }
+
+    private static IndexRequest mockIndexRequest(long ramBytes) {
+        IndexRequest indexRequest = mock(IndexRequest.class);
+        when(indexRequest.ramBytesUsed()).thenReturn(ramBytes);
+        return indexRequest;
+    }
+}