[7.x][ML] Ensure bulk requests are not over memory limit (#60219) (#60283)

Data frame analytics jobs that work with very large datasets
may produce bulk requests that exceed the memory limit
for indexing. This commit adds a helper class that bundles
index requests into bulk requests that stay below that
limit. We then use this class from both the results joiner
and the inference runner, ensuring that data frame analytics
jobs do not generate bulk requests that are too large.

Note the limit was implemented in #58885.

Backport of #60219
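
For orientation, here is a minimal sketch of the call pattern the new helper enables. The caller class and method names are illustrative only; the real call sites are DataFrameRowsJoiner#joinCurrentResults and InferenceRunner#inferTestDocs in the diff below.

    import org.elasticsearch.action.bulk.BulkRequest;
    import org.elasticsearch.action.index.IndexRequest;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;

    import java.util.List;

    class BulkIndexingSketch { // hypothetical caller, for illustration only

        private void executeBulkRequest(BulkRequest bulkRequest) {
            // In the real callers this delegates to ResultsPersisterService.
        }

        void indexAll(Settings settings, List<IndexRequest> requests) {
            // The indexer flushes before an add would push the buffered bulk request
            // over half of IndexingPressure.MAX_INDEXING_BYTES or past the
            // 1000-action batch size; close() (via try-with-resources) flushes
            // whatever is still buffered.
            try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
                for (IndexRequest request : requests) {
                    bulkIndexer.addAndExecuteIfNeeded(request);
                }
            }
        }
    }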
Dimitris Athanasiou 2020-07-28 16:04:03 +03:00 committed by GitHub
parent b78caa5c00
commit 16ffcfb9f6
9 changed files with 228 additions and 59 deletions

MachineLearning.java

@@ -694,7 +694,9 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
         this.modelLoadingService.set(modelLoadingService);
 
         // Data frame analytics components
-        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client,
+        AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(
+            settings,
+            client,
             threadPool,
             analyticsProcessFactory,
             dataFrameAnalyticsAuditor,

InferenceRunner.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.OriginSettingClient;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ClientHelper;
@@ -27,6 +28,7 @@ import org.elasticsearch.xpack.ml.extractor.ExtractedField;
 import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
+import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.util.Deque;
@@ -40,8 +42,8 @@ public class InferenceRunner {
 
     private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class);
     private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
-    private static final int RESULTS_BATCH_SIZE = 1000;
 
+    private final Settings settings;
     private final Client client;
     private final ModelLoadingService modelLoadingService;
     private final ResultsPersisterService resultsPersisterService;
@@ -52,9 +54,10 @@
     private final DataCountsTracker dataCountsTracker;
     private volatile boolean isCancelled;
 
-    public InferenceRunner(Client client, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService,
-                           TaskId parentTaskId, DataFrameAnalyticsConfig config, ExtractedFields extractedFields,
-                           ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
+    public InferenceRunner(Settings settings, Client client, ModelLoadingService modelLoadingService,
+                           ResultsPersisterService resultsPersisterService, TaskId parentTaskId, DataFrameAnalyticsConfig config,
+                           ExtractedFields extractedFields, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
         this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
@@ -92,36 +95,29 @@
     void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator) {
         long totalDocCount = 0;
         long processedDocCount = 0;
-        BulkRequest bulkRequest = new BulkRequest();
-
-        while (testDocsIterator.hasNext()) {
-            if (isCancelled) {
-                break;
-            }
-
-            Deque<SearchHit> batch = testDocsIterator.next();
-            if (totalDocCount == 0) {
-                totalDocCount = testDocsIterator.getTotalHits();
-            }
-
-            for (SearchHit doc : batch) {
-                dataCountsTracker.incrementTestDocsCount();
-                InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
-                bulkRequest.add(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));
-                processedDocCount++;
-
-                int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
-                progressTracker.updateInferenceProgress(progressPercent);
-            }
-
-            if (bulkRequest.numberOfActions() == RESULTS_BATCH_SIZE) {
-                executeBulkRequest(bulkRequest);
-                bulkRequest = new BulkRequest();
-            }
-        }
-
-        if (bulkRequest.numberOfActions() > 0 && isCancelled == false) {
-            executeBulkRequest(bulkRequest);
-        }
+        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
+            while (testDocsIterator.hasNext()) {
+                if (isCancelled) {
+                    break;
+                }
+
+                Deque<SearchHit> batch = testDocsIterator.next();
+                if (totalDocCount == 0) {
+                    totalDocCount = testDocsIterator.getTotalHits();
+                }
+
+                for (SearchHit doc : batch) {
+                    dataCountsTracker.incrementTestDocsCount();
+                    InferenceResults inferenceResults = model.inferNoStats(featuresFromDoc(doc));
+                    bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(doc, inferenceResults, config.getDest().getResultsField()));
+                    processedDocCount++;
+
+                    int progressPercent = Math.min((int) (processedDocCount * 100.0 / totalDocCount), MAX_PROGRESS_BEFORE_COMPLETION);
+                    progressTracker.updateInferenceProgress(progressPercent);
+                }
+            }
+        }
 
         if (isCancelled == false) {

AnalyticsProcessManager.java

@@ -18,6 +18,7 @@ import org.elasticsearch.client.ParentTaskAssigningClient;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.SearchHit;
@@ -61,6 +62,7 @@ public class AnalyticsProcessManager {
 
     private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
 
+    private final Settings settings;
     private final Client client;
     private final ExecutorService executorServiceForJob;
     private final ExecutorService executorServiceForProcess;
@@ -72,7 +74,8 @@
     private final ResultsPersisterService resultsPersisterService;
     private final int numAllocatedProcessors;
 
-    public AnalyticsProcessManager(Client client,
+    public AnalyticsProcessManager(Settings settings,
+                                   Client client,
                                    ThreadPool threadPool,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
                                    DataFrameAnalyticsAuditor auditor,
@@ -81,6 +84,7 @@
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
         this(
+            settings,
             client,
             threadPool.generic(),
             threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
@@ -93,7 +97,8 @@
     }
 
     // Visible for testing
-    public AnalyticsProcessManager(Client client,
+    public AnalyticsProcessManager(Settings settings,
+                                   Client client,
                                    ExecutorService executorServiceForJob,
                                    ExecutorService executorServiceForProcess,
                                    AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
@@ -102,6 +107,7 @@
                                    ModelLoadingService modelLoadingService,
                                    ResultsPersisterService resultsPersisterService,
                                    int numAllocatedProcessors) {
+        this.settings = Objects.requireNonNull(settings);
         this.client = Objects.requireNonNull(client);
         this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
         this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
@@ -330,7 +336,7 @@
 
         if (processContext.config.getAnalysis().supportsInference()) {
             refreshDest(parentTaskClient, processContext.config);
-            InferenceRunner inferenceRunner = new InferenceRunner(parentTaskClient, modelLoadingService, resultsPersisterService,
+            InferenceRunner inferenceRunner = new InferenceRunner(settings, parentTaskClient, modelLoadingService, resultsPersisterService,
                 task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(),
                 task.getStatsHolder().getDataCountsTracker());
             processContext.setInferenceRunner(inferenceRunner);
@@ -489,7 +495,7 @@
     private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
                                                            DataFrameDataExtractorFactory dataExtractorFactory) {
         DataFrameRowsJoiner dataFrameRowsJoiner =
-            new DataFrameRowsJoiner(config.getId(), task.getParentTaskId(),
+            new DataFrameRowsJoiner(config.getId(), settings, task.getParentTaskId(),
                 dataExtractorFactory.newExtractor(true), resultsPersisterService);
         return new AnalyticsResultProcessor(
             config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, statsPersister,

DataFrameRowsJoiner.java

@@ -12,11 +12,13 @@ import org.elasticsearch.action.DocWriteRequest;
 import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
 import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
+import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
 import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
 
 import java.io.IOException;
@@ -36,6 +38,7 @@ class DataFrameRowsJoiner implements AutoCloseable {
     private static final int RESULTS_BATCH_SIZE = 1000;
 
     private final String analyticsId;
+    private final Settings settings;
     private final TaskId parentTaskId;
     private final DataFrameDataExtractor dataExtractor;
     private final ResultsPersisterService resultsPersisterService;
@@ -44,9 +47,10 @@
     private volatile String failure;
     private volatile boolean isCancelled;
 
-    DataFrameRowsJoiner(String analyticsId, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
+    DataFrameRowsJoiner(String analyticsId, Settings settings, TaskId parentTaskId, DataFrameDataExtractor dataExtractor,
                         ResultsPersisterService resultsPersisterService) {
         this.analyticsId = Objects.requireNonNull(analyticsId);
+        this.settings = Objects.requireNonNull(settings);
         this.parentTaskId = Objects.requireNonNull(parentTaskId);
         this.dataExtractor = Objects.requireNonNull(dataExtractor);
         this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
@@ -86,25 +90,28 @@
     }
 
     private void joinCurrentResults() {
-        BulkRequest bulkRequest = new BulkRequest();
-        while (currentResults.isEmpty() == false) {
-            RowResults result = currentResults.pop();
-            DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
-            checkChecksumsMatch(row, result);
-            bulkRequest.add(createIndexRequest(result, row.getHit()));
-        }
-
-        if (bulkRequest.numberOfActions() > 0) {
-            bulkRequest.setParentTask(parentTaskId);
-            resultsPersisterService.bulkIndexWithHeadersWithRetry(
-                dataExtractor.getHeaders(),
-                bulkRequest,
-                analyticsId,
-                () -> isCancelled == false,
-                errorMsg -> {});
-        }
+        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(settings, this::executeBulkRequest)) {
+            while (currentResults.isEmpty() == false) {
+                RowResults result = currentResults.pop();
+                DataFrameDataExtractor.Row row = dataFrameRowsIterator.next();
+                checkChecksumsMatch(row, result);
+                bulkIndexer.addAndExecuteIfNeeded(createIndexRequest(result, row.getHit()));
+            }
+        }
+
         currentResults = new LinkedList<>();
     }
 
+    private void executeBulkRequest(BulkRequest bulkRequest) {
+        bulkRequest.setParentTask(parentTaskId);
+        resultsPersisterService.bulkIndexWithHeadersWithRetry(
+            dataExtractor.getHeaders(),
+            bulkRequest,
+            analyticsId,
+            () -> isCancelled == false,
+            errorMsg -> {});
+    }
+
     private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) {
         if (row.getChecksum() != result.getChecksum()) {
             String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; ";

LimitAwareBulkIndexer.java (new file)

@@ -0,0 +1,66 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexingPressure;
+
+import java.util.Objects;
+import java.util.function.Consumer;
+
+/**
+ * A helper class that gathers index requests into bulk requests
+ * that do not exceed 1000 operations or half the available memory
+ * limit for indexing.
+ */
+public class LimitAwareBulkIndexer implements AutoCloseable {
+
+    private static final Logger LOGGER = LogManager.getLogger(LimitAwareBulkIndexer.class);
+
+    private static final int BATCH_SIZE = 1000;
+
+    private final long bytesLimit;
+    private final Consumer<BulkRequest> executor;
+    private BulkRequest currentBulkRequest = new BulkRequest();
+    private long currentRamBytes;
+
+    public LimitAwareBulkIndexer(Settings settings, Consumer<BulkRequest> executor) {
+        this((long) Math.ceil(0.5 * IndexingPressure.MAX_INDEXING_BYTES.get(settings).getBytes()), executor);
+    }
+
+    LimitAwareBulkIndexer(long bytesLimit, Consumer<BulkRequest> executor) {
+        this.bytesLimit = bytesLimit;
+        this.executor = Objects.requireNonNull(executor);
+    }
+
+    public void addAndExecuteIfNeeded(IndexRequest indexRequest) {
+        if (currentRamBytes + indexRequest.ramBytesUsed() > bytesLimit || currentBulkRequest.numberOfActions() == BATCH_SIZE) {
+            execute();
+        }
+        currentBulkRequest.add(indexRequest);
+        currentRamBytes += indexRequest.ramBytesUsed();
+    }
+
+    private void execute() {
+        if (currentBulkRequest.numberOfActions() > 0) {
+            LOGGER.debug("Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]",
+                currentRamBytes, bytesLimit, currentBulkRequest.numberOfActions());
+            executor.accept(currentBulkRequest);
+            currentBulkRequest = new BulkRequest();
+            currentRamBytes = 0;
+        }
+    }
+
+    @Override
+    public void close() {
+        execute();
+    }
+}
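
One consequence of checking the limit before adding is worth noting: a single index request larger than bytesLimit is not rejected; the buffer is flushed first, and the oversized request then travels alone in the next flush. A small sketch of that behavior using the package-private constructor (the same entry point the tests below use; send() and the request variables are illustrative):

    LimitAwareBulkIndexer indexer = new LimitAwareBulkIndexer(100, bulk -> send(bulk));
    indexer.addAndExecuteIfNeeded(request60a);  // buffered: 60 bytes, no flush yet
    indexer.addAndExecuteIfNeeded(request60b);  // 60 + 60 > 100: flush the first request, then buffer this one
    indexer.addAndExecuteIfNeeded(request250);  // 60 + 250 > 100: flush, then buffer the oversized request
    indexer.close();                            // flushes the 250-byte request as a single-action bulk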

InferenceRunnerTests.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.search.SearchHit;
@@ -164,7 +165,7 @@ public class InferenceRunnerTests extends ESTestCase {
     }
 
     private InferenceRunner createInferenceRunner(ExtractedFields extractedFields) {
-        return new InferenceRunner(client, modelLoadingService, resultsPersisterService, parentTaskId, config, extractedFields,
-            progressTracker, new DataCountsTracker());
+        return new InferenceRunner(Settings.EMPTY, client, modelLoadingService, resultsPersisterService, parentTaskId, config,
+            extractedFields, progressTracker, new DataCountsTracker());
     }
 }

AnalyticsProcessManagerTests.java

@@ -112,8 +112,8 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         resultsPersisterService = mock(ResultsPersisterService.class);
         modelLoadingService = mock(ModelLoadingService.class);
 
-        processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
-            trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
+        processManager = new AnalyticsProcessManager(Settings.EMPTY, client, executorServiceForJob, executorServiceForProcess,
+            processFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, 1);
     }
 
     public void testRunJob_TaskIsStopping() {

DataFrameRowsJoinerTests.java

@@ -10,6 +10,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
 import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.text.Text;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.tasks.TaskId;
@@ -31,6 +32,7 @@ import java.util.Optional;
 import java.util.stream.IntStream;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
+import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
@@ -264,7 +266,10 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
         RowResults result2 = new RowResults(2, resultFields);
         givenProcessResults(Arrays.asList(result1, result2));
 
-        verifyNoMoreInteractions(resultsPersisterService);
+        List<BulkRequest> capturedBulkRequests = bulkRequestCaptor.getAllValues();
+        assertThat(capturedBulkRequests, hasSize(1));
+        BulkRequest capturedBulkRequest = capturedBulkRequests.get(0);
+        assertThat(capturedBulkRequest.numberOfActions(), equalTo(1));
     }
 
     public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
@@ -284,7 +289,8 @@
     }
 
     private void givenProcessResults(List<RowResults> results) {
-        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, new TaskId(""), dataExtractor, resultsPersisterService)) {
+        try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, Settings.EMPTY, new TaskId(""), dataExtractor,
+                resultsPersisterService)) {
             results.forEach(joiner::processRowResults);
         }
     }

LimitAwareBulkIndexerTests.java (new file)

@@ -0,0 +1,85 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.ml.utils.persistence;
+
+import org.elasticsearch.action.bulk.BulkRequest;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class LimitAwareBulkIndexerTests extends ESTestCase {
+
+    private List<BulkRequest> executedBulkRequests = new ArrayList<>();
+
+    public void testAddAndExecuteIfNeeded_GivenRequestsReachingBytesLimit() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(100)) {
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(1));
+            assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(2));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(1));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(50));
+            assertThat(executedBulkRequests, hasSize(2));
+            assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(2));
+        }
+
+        assertThat(executedBulkRequests, hasSize(3));
+        assertThat(executedBulkRequests.get(2).numberOfActions(), equalTo(1));
+    }
+
+    public void testAddAndExecuteIfNeeded_GivenRequestsReachingBatchSize() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) {
+            for (int i = 0; i < 1000; i++) {
+                bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1));
+            }
+            assertThat(executedBulkRequests, is(empty()));
+
+            bulkIndexer.addAndExecuteIfNeeded(mockIndexRequest(1));
+            assertThat(executedBulkRequests, hasSize(1));
+            assertThat(executedBulkRequests.get(0).numberOfActions(), equalTo(1000));
+        }
+
+        assertThat(executedBulkRequests, hasSize(2));
+        assertThat(executedBulkRequests.get(1).numberOfActions(), equalTo(1));
+    }
+
+    public void testNoRequests() {
+        try (LimitAwareBulkIndexer bulkIndexer = createIndexer(10000)) {
+        }
+
+        assertThat(executedBulkRequests, is(empty()));
+    }
+
+    private LimitAwareBulkIndexer createIndexer(long bytesLimit) {
+        return new LimitAwareBulkIndexer(bytesLimit, executedBulkRequests::add);
+    }
+
+    private static IndexRequest mockIndexRequest(long ramBytes) {
+        IndexRequest indexRequest = mock(IndexRequest.class);
+        when(indexRequest.ramBytesUsed()).thenReturn(ramBytes);
+        return indexRequest;
+    }
+}