Uses `ResultsPersisterService` in `DataFrameRowsJoiner` in order to add retries when a data frame analytics job persists its results to the destination data frame. Backport of #52048.
This commit is contained in:
parent
2f1631d9d0
commit
cbebc26f50
|
@ -634,7 +634,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin, Analys
|
|||
|
||||
// Data frame analytics components
|
||||
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
|
||||
dataFrameAnalyticsAuditor, trainedModelProvider);
|
||||
dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
|
||||
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
|
||||
new MemoryUsageEstimationProcessManager(
|
||||
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
|||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
@ -62,19 +63,22 @@ public class AnalyticsProcessManager {
|
|||
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
|
||||
private final DataFrameAnalyticsAuditor auditor;
|
||||
private final TrainedModelProvider trainedModelProvider;
|
||||
private final ResultsPersisterService resultsPersisterService;
|
||||
|
||||
public AnalyticsProcessManager(Client client,
|
||||
ThreadPool threadPool,
|
||||
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
|
||||
DataFrameAnalyticsAuditor auditor,
|
||||
TrainedModelProvider trainedModelProvider) {
|
||||
TrainedModelProvider trainedModelProvider,
|
||||
ResultsPersisterService resultsPersisterService) {
|
||||
this(
|
||||
client,
|
||||
threadPool.generic(),
|
||||
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
|
||||
analyticsProcessFactory,
|
||||
auditor,
|
||||
trainedModelProvider);
|
||||
trainedModelProvider,
|
||||
resultsPersisterService);
|
||||
}
|
||||
|
||||
// Visible for testing
|
||||
|
@ -83,13 +87,15 @@ public class AnalyticsProcessManager {
|
|||
ExecutorService executorServiceForProcess,
|
||||
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
|
||||
DataFrameAnalyticsAuditor auditor,
|
||||
TrainedModelProvider trainedModelProvider) {
|
||||
TrainedModelProvider trainedModelProvider,
|
||||
ResultsPersisterService resultsPersisterService) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
|
||||
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
|
||||
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
|
||||
this.auditor = Objects.requireNonNull(auditor);
|
||||
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
|
||||
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
|
||||
}
|
||||
|
||||
public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
|
||||
|
@ -419,7 +425,7 @@ public class AnalyticsProcessManager {
|
|||
private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task,
|
||||
DataFrameDataExtractorFactory dataExtractorFactory) {
|
||||
DataFrameRowsJoiner dataFrameRowsJoiner =
|
||||
new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true));
|
||||
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
|
||||
return new AnalyticsResultProcessor(
|
||||
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
|
||||
}
|
||||
|
|
|
@ -9,17 +9,14 @@ import org.apache.logging.log4j.LogManager;
|
|||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.action.DocWriteRequest;
|
||||
import org.elasticsearch.action.bulk.BulkAction;
|
||||
import org.elasticsearch.action.bulk.BulkRequest;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
@ -38,16 +35,17 @@ class DataFrameRowsJoiner implements AutoCloseable {
|
|||
private static final int RESULTS_BATCH_SIZE = 1000;
|
||||
|
||||
private final String analyticsId;
|
||||
private final Client client;
|
||||
private final DataFrameDataExtractor dataExtractor;
|
||||
private final ResultsPersisterService resultsPersisterService;
|
||||
private final Iterator<DataFrameDataExtractor.Row> dataFrameRowsIterator;
|
||||
private LinkedList<RowResults> currentResults;
|
||||
private volatile String failure;
|
||||
|
||||
DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) {
|
||||
DataFrameRowsJoiner(String analyticsId, DataFrameDataExtractor dataExtractor,
|
||||
ResultsPersisterService resultsPersisterService) {
|
||||
this.analyticsId = Objects.requireNonNull(analyticsId);
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.dataExtractor = Objects.requireNonNull(dataExtractor);
|
||||
this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
|
||||
this.dataFrameRowsIterator = new ResultMatchingDataFrameRows();
|
||||
this.currentResults = new LinkedList<>();
|
||||
}
|
||||
|
@ -88,7 +86,8 @@ class DataFrameRowsJoiner implements AutoCloseable {
|
|||
bulkRequest.add(createIndexRequest(result, row.getHit()));
|
||||
}
|
||||
if (bulkRequest.numberOfActions() > 0) {
|
||||
executeBulkRequest(bulkRequest);
|
||||
resultsPersisterService.bulkIndexWithHeadersWithRetry(
|
||||
dataExtractor.getHeaders(), bulkRequest, analyticsId, () -> true, errorMsg -> {});
|
||||
}
|
||||
currentResults = new LinkedList<>();
|
||||
}
|
||||
|
@ -113,14 +112,6 @@ class DataFrameRowsJoiner implements AutoCloseable {
|
|||
return indexRequest;
|
||||
}
|
||||
|
||||
/**
 * Executes the given bulk request through the client, using the data extractor's headers and the ML origin,
 * and blocks until the response is available.
 * Only a single attempt is made here: if the response reports any failures, a server error carrying the
 * bulk failure message is thrown.
 *
 * @param bulkRequest the bulk request holding the results to write
 */
private void executeBulkRequest(BulkRequest bulkRequest) {
|
||||
BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client,
|
||||
() -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet());
|
||||
if (bulkResponse.hasFailures()) {
|
||||
throw ExceptionsHelper.serverError("failures while writing results [" + bulkResponse.buildFailureMessage() + "]");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
try {
|
||||
|
|
|
@ -9,6 +9,7 @@ import org.apache.logging.log4j.LogManager;
|
|||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.bulk.BulkAction;
|
||||
import org.elasticsearch.action.bulk.BulkItemResponse;
|
||||
import org.elasticsearch.action.bulk.BulkRequest;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
|
@ -27,13 +28,16 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
|||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.time.Duration;
|
||||
import java.util.Arrays;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
@ -95,9 +99,28 @@ public class ResultsPersisterService {
|
|||
String jobId,
|
||||
Supplier<Boolean> shouldRetry,
|
||||
Consumer<String> msgHandler) {
|
||||
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
|
||||
providedBulkRequest -> client.bulk(providedBulkRequest).actionGet());
|
||||
}
|
||||
|
||||
public BulkResponse bulkIndexWithHeadersWithRetry(Map<String, String> headers,
|
||||
BulkRequest bulkRequest,
|
||||
String jobId,
|
||||
Supplier<Boolean> shouldRetry,
|
||||
Consumer<String> msgHandler) {
|
||||
return bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, msgHandler,
|
||||
providedBulkRequest -> ClientHelper.executeWithHeaders(
|
||||
headers, ClientHelper.ML_ORIGIN, client, () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()));
|
||||
}
|
||||
|
||||
private BulkResponse bulkIndexWithRetry(BulkRequest bulkRequest,
|
||||
String jobId,
|
||||
Supplier<Boolean> shouldRetry,
|
||||
Consumer<String> msgHandler,
|
||||
Function<BulkRequest, BulkResponse> actionExecutor) {
|
||||
RetryContext retryContext = new RetryContext(jobId, shouldRetry, msgHandler);
|
||||
while (true) {
|
||||
BulkResponse bulkResponse = client.bulk(bulkRequest).actionGet();
|
||||
BulkResponse bulkResponse = actionExecutor.apply(bulkRequest);
|
||||
if (bulkResponse.hasFailures() == false) {
|
||||
return bulkResponse;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
|||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
||||
import org.junit.Before;
|
||||
import org.mockito.InOrder;
|
||||
|
||||
|
@ -65,6 +66,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
|||
private DataFrameAnalyticsConfig dataFrameAnalyticsConfig;
|
||||
private DataFrameDataExtractorFactory dataExtractorFactory;
|
||||
private DataFrameDataExtractor dataExtractor;
|
||||
private ResultsPersisterService resultsPersisterService;
|
||||
private AnalyticsProcessManager processManager;
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
|
@ -97,8 +99,10 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
|
|||
when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor);
|
||||
when(dataExtractorFactory.getExtractedFields()).thenReturn(mock(ExtractedFields.class));
|
||||
|
||||
processManager = new AnalyticsProcessManager(
|
||||
client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider);
|
||||
resultsPersisterService = mock(ResultsPersisterService.class);
|
||||
|
||||
processManager = new AnalyticsProcessManager(client, executorServiceForJob, executorServiceForProcess, processFactory, auditor,
|
||||
trainedModelProvider, resultsPersisterService);
|
||||
}
|
||||
|
||||
public void testRunJob_TaskIsStopping() {
|
||||
|
|
|
@ -5,22 +5,17 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process;
|
||||
|
||||
import org.elasticsearch.action.ActionFuture;
|
||||
import org.elasticsearch.action.bulk.BulkAction;
|
||||
import org.elasticsearch.action.bulk.BulkItemResponse;
|
||||
import org.elasticsearch.action.bulk.BulkRequest;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.bytes.BytesArray;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.text.Text;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
|
||||
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
|
||||
import org.junit.Before;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
|
@ -35,7 +30,8 @@ import java.util.Optional;
|
|||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.Matchers.same;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
@ -46,19 +42,22 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
|
|||
|
||||
private static final String ANALYTICS_ID = "my_analytics";
|
||||
|
||||
private Client client;
|
||||
private static final Map<String, String> HEADERS = Collections.singletonMap("foo", "bar");
|
||||
|
||||
private DataFrameDataExtractor dataExtractor;
|
||||
private ResultsPersisterService resultsPersisterService;
|
||||
private ArgumentCaptor<BulkRequest> bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class);
|
||||
|
||||
@Before
|
||||
public void setUpMocks() {
|
||||
client = mock(Client.class);
|
||||
dataExtractor = mock(DataFrameDataExtractor.class);
|
||||
when(dataExtractor.getHeaders()).thenReturn(HEADERS);
|
||||
resultsPersisterService = mock(ResultsPersisterService.class);
|
||||
}
|
||||
|
||||
public void testProcess_GivenNoResults() {
|
||||
givenProcessResults(Collections.emptyList());
|
||||
verifyNoMoreInteractions(client);
|
||||
verifyNoMoreInteractions(resultsPersisterService);
|
||||
}
|
||||
|
||||
public void testProcess_GivenSingleRowAndResult() throws IOException {
|
||||
|
@ -126,7 +125,7 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
|
|||
RowResults result = new RowResults(2, resultFields);
|
||||
givenProcessResults(Arrays.asList(result));
|
||||
|
||||
verifyNoMoreInteractions(client);
|
||||
verifyNoMoreInteractions(resultsPersisterService);
|
||||
}
|
||||
|
||||
public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException {
|
||||
|
@ -204,7 +203,7 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
|
|||
RowResults result2 = new RowResults(2, resultFields);
|
||||
givenProcessResults(Arrays.asList(result1, result2));
|
||||
|
||||
verifyNoMoreInteractions(client);
|
||||
verifyNoMoreInteractions(resultsPersisterService);
|
||||
}
|
||||
|
||||
public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException {
|
||||
|
@ -218,13 +217,13 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
|
|||
|
||||
givenProcessResults(Collections.emptyList());
|
||||
|
||||
verifyNoMoreInteractions(client);
|
||||
verifyNoMoreInteractions(resultsPersisterService);
|
||||
verify(dataExtractor).cancel();
|
||||
verify(dataExtractor, times(2)).next();
|
||||
}
|
||||
|
||||
private void givenProcessResults(List<RowResults> results) {
|
||||
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) {
|
||||
try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, dataExtractor, resultsPersisterService)) {
|
||||
results.forEach(joiner::processRowResults);
|
||||
}
|
||||
}
|
||||
|
@ -251,14 +250,9 @@ public class DataFrameRowsJoinerTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private void givenClientHasNoFailures() {
|
||||
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
|
||||
ThreadPool threadPool = mock(ThreadPool.class);
|
||||
when(threadPool.getThreadContext()).thenReturn(threadContext);
|
||||
@SuppressWarnings("unchecked")
|
||||
ActionFuture<BulkResponse> responseFuture = mock(ActionFuture.class);
|
||||
when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
|
||||
when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture);
|
||||
when(client.threadPool()).thenReturn(threadPool);
|
||||
when(resultsPersisterService.bulkIndexWithHeadersWithRetry(
|
||||
eq(HEADERS), bulkRequestCaptor.capture(), eq(ANALYTICS_ID), any(), any()))
|
||||
.thenReturn(new BulkResponse(new BulkItemResponse[0], 0));
|
||||
}
|
||||
|
||||
private static class DelegateStubDataExtractor {
|
||||
|
|
Loading…
Reference in New Issue