[7.x][ML] Retry persisting DF Analytics results (#52048) (#52160)

Employs `ResultsPersisterService` from `DataFrameRowsJoiner` in order
to add retries when a data frame analytics job is persisting the results
to the destination data frame.

Backport of #52048
This commit is contained in:
Dimitris Athanasiou 2020-02-11 09:55:00 +02:00 committed by GitHub
parent 2f1631d9d0
commit cbebc26f50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 46 deletions

View File

@ -634,7 +634,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
// Data frame analytics components // Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory, AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
dataFrameAnalyticsAuditor, trainedModelProvider); dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
MemoryUsageEstimationProcessManager memoryEstimationProcessManager = MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
new MemoryUsageEstimationProcessManager( new MemoryUsageEstimationProcessManager(
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory); threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);

View File

@ -38,6 +38,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -62,19 +63,22 @@ public class AnalyticsProcessManager {
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>(); private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor; private final DataFrameAnalyticsAuditor auditor;
private final TrainedModelProvider trainedModelProvider; private final TrainedModelProvider trainedModelProvider;
private final ResultsPersisterService resultsPersisterService;
public AnalyticsProcessManager(Client client, public AnalyticsProcessManager(Client client,
ThreadPool threadPool, ThreadPool threadPool,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory, AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) { TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this( this(
client, client,
threadPool.generic(), threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory, analyticsProcessFactory,
auditor, auditor,
trainedModelProvider); trainedModelProvider,
resultsPersisterService);
} }
// Visible for testing // Visible for testing
@ -83,13 +87,15 @@ public class AnalyticsProcessManager {
ExecutorService executorServiceForProcess, ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory, AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor, DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) { TrainedModelProvider trainedModelProvider,
ResultsPersisterService resultsPersisterService) {
this.client = Objects.requireNonNull(client); this.client = Objects.requireNonNull(client);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob); this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess); this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor); this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
} }
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) { public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
@ -419,7 +425,7 @@ public class AnalyticsProcessManager {
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task, private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
DataFrameDataExtractorFactory dataExtractorFactory) { DataFrameDataExtractorFactory dataExtractorFactory) {
DataFrameRowsJoiner dataFrameRowsJoiner = DataFrameRowsJoiner dataFrameRowsJoiner =
new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor( return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames()); config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
} }

View File

@ -9,17 +9,14 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable {
private static final int RESULTS_BATCH_SIZE = 1000; private static final int RESULTS_BATCH_SIZE = 1000;
private final String analyticsId; private final String analyticsId;
private final Client client;
private final DataFrameDataExtractor dataExtractor; private final DataFrameDataExtractor dataExtractor;
private final ResultsPersisterService resultsPersisterService;
private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator; private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
private LinkedList<RowResults> currentResults; private LinkedList<RowResults> currentResults;
private volatile String failure; private volatile String failure;
DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) { DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor,
ResultsPersisterService resultsPersisterService) {
this.analyticsId = Objects.requireNonNull(analyticsId); this.analyticsId = Objects.requireNonNull(analyticsId);
this.client = Objects.requireNonNull(client);
this.dataExtractor = Objects.requireNonNull(dataExtractor); this.dataExtractor = Objects.requireNonNull(dataExtractor);
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
this.dataFrameRowsIterator = new ResultMatchingDataFrameRows(); this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
this.currentResults = new LinkedList<>(); this.currentResults = new LinkedList<>();
} }
@ -88,7 +86,8 @@ class DataFrameRowsJoiner implements AutoCloseable {
bulkRequest.add(createIndexRequest(result, row.getHit())); bulkRequest.add(createIndexRequest(result, row.getHit()));
} }
if (bulkRequest.numberOfActions() > 0) { if (bulkRequest.numberOfActions() > 0) {
executeBulkRequest(bulkRequest); resultsPersisterService.bulkIndexWithHeadersWithRetry(
dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {});
} }
currentResults = new LinkedList<>(); currentResults = new LinkedList<>();
} }
@ -113,14 +112,6 @@ class DataFrameRowsJoiner implements AutoCloseable {
return indexRequest; return indexRequest;
} }
private void executeBulkRequest(BulkRequest bulkRequest) {
BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client,
() -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet());
if (bulkResponse.hasFailures()) {
throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]");
}
}
@Override @Override
public void close() { public void close() {
try { try {

View File

@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
@ -27,13 +28,16 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -95,9 +99,28 @@ public class ResultsPersisterService {
String jobId, String jobId,
Supplier<Boolean> shouldRetry, Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler) { Consumer<String> msgHandler) {
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
}
public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler) {
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
providedBulkRequest -> ClientHelper.executeWithHeaders(
headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()));
}
private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
String jobId,
Supplier<Boolean> shouldRetry,
Consumer<String> msgHandler,
Function<BulkRequest, BulkResponse> actionExecutor) {
RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler); RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler);
while (true) { while (true) {
BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet(); BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
if (bulkResponse.hasFailures() == false) { if (bulkResponse.hasFailures() == false) {
return bulkResponse; return bulkResponse;
} }

View File

@ -23,6 +23,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before; import org.junit.Before;
import org.mockito.InOrder; import org.mockito.InOrder;
@ -65,6 +66,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig; private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
private DataFrameDataExtractorFactory dataExtractorFactory; private DataFrameDataExtractorFactory dataExtractorFactory;
private DataFrameDataExtractor dataExtractor; private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private AnalyticsProcessManager processManager; private AnalyticsProcessManager processManager;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -97,8 +99,10 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class)); when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
processManager = new AnalyticsProcessManager( resultsPersisterService = mock(ResultsPersisterService.class);
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
trainedModelProvider, resultsPersisterService);
} }
public void testRunJob_TaskIsStopping() { public void testRunJob_TaskIsStopping() {

View File

@ -5,22 +5,17 @@
*/ */
package org.elasticsearch.xpack.ml.dataframe.process; package org.elasticsearch.xpack.ml.dataframe.process;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.text.Text; import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -35,7 +30,8 @@ import java.util.Optional;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.same; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -46,19 +42,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
private static final String ANALYTICS_ID = "my_analytics"; private static final String ANALYTICS_ID = "my_analytics";
private Client client; private static final Map<String, String> HEADERS = Collections.singletonMap("foo", "bar");
private DataFrameDataExtractor dataExtractor; private DataFrameDataExtractor dataExtractor;
private ResultsPersisterService resultsPersisterService;
private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
@Before @Before
public void setUpMocks() { public void setUpMocks() {
client = mock(Client.class);
dataExtractor = mock(DataFrameDataExtractor.class); dataExtractor = mock(DataFrameDataExtractor.class);
when(dataExtractor.getHeaders()).thenReturn(HEADERS);
resultsPersisterService = mock(ResultsPersisterService.class);
} }
public void testProcess_GivenNoResults() { public void testProcess_GivenNoResults() {
givenProcessResults(Collections.emptyList()); givenProcessResults(Collections.emptyList());
verifyNoMoreInteractions(client); verifyNoMoreInteractions(resultsPersisterService);
} }
public void testProcess_GivenSingleRowAndResult() throws IOException { public void testProcess_GivenSingleRowAndResult() throws IOException {
@ -126,7 +125,7 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
RowResults result = new RowResults(2, resultFields); RowResults result = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result)); givenProcessResults(Arrays.asList(result));
verifyNoMoreInteractions(client); verifyNoMoreInteractions(resultsPersisterService);
} }
public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException { public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException {
@ -204,7 +203,7 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
RowResults result2 = new RowResults(2, resultFields); RowResults result2 = new RowResults(2, resultFields);
givenProcessResults(Arrays.asList(result1, result2)); givenProcessResults(Arrays.asList(result1, result2));
verifyNoMoreInteractions(client); verifyNoMoreInteractions(resultsPersisterService);
} }
public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException { public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
@ -218,13 +217,13 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
givenProcessResults(Collections.emptyList()); givenProcessResults(Collections.emptyList());
verifyNoMoreInteractions(client); verifyNoMoreInteractions(resultsPersisterService);
verify(dataExtractor).cancel(); verify(dataExtractor).cancel();
verify(dataExtractor, times(2)).next(); verify(dataExtractor, times(2)).next();
} }
private void givenProcessResults(List<RowResults> results) { private void givenProcessResults(List<RowResults> results) {
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) { try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
results.forEach(joiner::processRowResults); results.forEach(joiner::processRowResults);
} }
} }
@ -251,14 +250,9 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
} }
private void givenClientHasNoFailures() { private void givenClientHasNoFailures() {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(resultsPersisterService.bulkIndexWithHeadersWithRetry(
ThreadPool threadPool = mock(ThreadPool.class); eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any()))
when(threadPool.getThreadContext()).thenReturn(threadContext); .thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
@SuppressWarnings("unchecked")
ActionFuture<BulkResponse> responseFuture = mock(ActionFuture.class);
when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture);
when(client.threadPool()).thenReturn(threadPool);
} }
private static class DelegateStubDataExtractor { private static class DelegateStubDataExtractor {