[7.x][ML] Persist progress when setting DFA task to failed (#61782) (#61792)

When an error occurs and we set the task to failed via
the `DataFrameAnalyticsTask.setFailed` method we do not
persist progress. If the job is later restarted, this means
we do not correctly restore from where we can but instead
we start the job from scratch and have to redo the reindexing
phase.

This commit solves this bug by persisting the progress before
setting the task to failed.

Backport of #61782
This commit is contained in:
Dimitris Athanasiou 2020-09-01 18:33:07 +03:00 committed by GitHub
parent d52ee17054
commit 2547cfbe54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 102 additions and 62 deletions

View File

@ -232,7 +232,7 @@ public class DataFrameAnalyticsManager {
Exception reindexError = getReindexError(task.getParams().getId(), reindexResponse);
if (reindexError != null) {
task.markAsFailed(reindexError);
task.setFailed(reindexError);
return;
}

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
@ -38,7 +37,6 @@ import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
@ -216,23 +214,25 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
error);
return;
}
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
persistProgress(client, taskParams.getId(), () -> {
LOGGER.error(new ParameterizedMessage("[{}] Setting task to failed", taskParams.getId()), error);
String reason = ExceptionsHelper.unwrapCause(error).getMessage();
DataFrameAnalyticsTaskState newTaskState =
new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.FAILED, getAllocationId(), reason);
updatePersistentTaskState(
newTaskState,
ActionListener.wrap(
updatedTask -> {
String message = Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_UPDATED_STATE_WITH_REASON,
DataFrameAnalyticsState.FAILED, reason);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
auditor.info(getParams().getId(), message);
LOGGER.info("[{}] {}", getParams().getId(), message);
},
e -> LOGGER.error(new ParameterizedMessage("[{}] Could not update task state to [{}] with reason [{}]",
getParams().getId(), DataFrameAnalyticsState.FAILED, reason), e)
)
);
});
}
public void updateReindexTaskProgress(ActionListener<Void> listener) {
@ -285,13 +285,12 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
}
// Visible for testing
static void persistProgress(Client client, String jobId, Runnable runnable) {
void persistProgress(Client client, String jobId, Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", jobId);
String progressDocId = StoredProgress.documentId(jobId);
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
// Step 4: Run the runnable provided as the argument
// Step 3: Run the runnable provided as the argument
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", jobId);
@ -304,7 +303,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
}
);
// Step 3: Create or update the progress document:
// Step 2: Create or update the progress document:
// - if the document did not exist, create the new one in the current write index
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
@ -317,8 +316,10 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
.id(progressDocId)
.setRequireAlias(AnomalyDetectorsIndex.jobStateIndexWriteAlias().equals(indexOrAlias))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<PhaseProgress> progress = statsHolder.getProgressTracker().report();
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
LOGGER.debug("[{}] Persisting progress is: {}", jobId, progress);
new StoredProgress(progress).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
indexRequest.source(jsonBuilder);
}
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
@ -330,28 +331,14 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
}
);
// Step 2: Search for existing progress document in .ml-state*
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
statsResponse -> {
stats.set(statsResponse.getResponse().results().get(0));
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", jobId), e);
runnable.run();
}
);
// Step 1: Fetch progress to be persisted
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(jobId);
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
// Step 1: Search for existing progress document in .ml-state*
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
}
/**

View File

@ -16,6 +16,10 @@ import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.persistent.PersistentTasksService;
import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction;
import org.elasticsearch.search.SearchHit;
@ -23,11 +27,10 @@ import org.elasticsearch.search.SearchHits;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
@ -36,6 +39,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -125,14 +129,25 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
assertThat(startingState, equalTo(StartingState.FINISHED));
}
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) {
private void testPersistProgress(SearchHits searchHits, String expectedIndexOrAlias) throws IOException {
Client client = mock(Client.class);
when(client.settings()).thenReturn(Settings.EMPTY);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(threadPool);
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse(1);
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
ClusterService clusterService = mock(ClusterService.class);
DataFrameAnalyticsManager analyticsManager = mock(DataFrameAnalyticsManager.class);
DataFrameAnalyticsAuditor auditor = mock(DataFrameAnalyticsAuditor.class);
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, threadPool, client);
List<PhaseProgress> progress = Arrays.asList(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 50),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0));
StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(
"task_id", Version.CURRENT, progress, false);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
@ -141,14 +156,20 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
TaskManager taskManager = mock(TaskManager.class);
Runnable runnable = mock(Runnable.class);
DataFrameAnalyticsTask.persistProgress(client, "task_id", runnable);
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Collections.emptyMap(), client, clusterService, analyticsManager, auditor, taskParams);
task.init(persistentTasksService, taskManager, "task-id", 42);
task.persistProgress(client, "task_id", runnable);
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
InOrder inOrder = inOrder(client, runnable);
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
inOrder.verify(runnable).run();
@ -157,27 +178,33 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}
}
public void testPersistProgress_ProgressDocumentCreated() {
public void testPersistProgress_ProgressDocumentCreated() throws IOException {
testPersistProgress(SearchHits.empty(), ".ml-state-write");
}
public void testPersistProgress_ProgressDocumentUpdated() {
public void testPersistProgress_ProgressDocumentUpdated() throws IOException {
testPersistProgress(
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
".ml-state-dummy");
}
public void testSetFailed() {
public void testSetFailed() throws IOException {
testSetFailed(false);
}
public void testSetFailedDuringNodeShutdown() {
public void testSetFailedDuringNodeShutdown() throws IOException {
testSetFailed(true);
}
private void testSetFailed(boolean nodeShuttingDown) {
private void testSetFailed(boolean nodeShuttingDown) throws IOException {
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
Client client = mock(Client.class);
@ -190,15 +217,25 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
PersistentTasksService persistentTasksService = new PersistentTasksService(clusterService, mock(ThreadPool.class), client);
TaskManager taskManager = mock(TaskManager.class);
List<PhaseProgress> progress = Arrays.asList(
new PhaseProgress(ProgressTracker.REINDEXING, 100),
new PhaseProgress(ProgressTracker.LOADING_DATA, 100),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 30));
StartDataFrameAnalyticsAction.TaskParams taskParams =
new StartDataFrameAnalyticsAction.TaskParams(
"job-id",
Version.CURRENT,
Arrays.asList(
new PhaseProgress(ProgressTracker.REINDEXING, 0),
new PhaseProgress(ProgressTracker.LOADING_DATA, 0),
new PhaseProgress(ProgressTracker.WRITING_RESULTS, 0)),
progress,
false);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(SearchHits.empty());
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
123, "type", "action", null, Collections.emptyMap(), client, clusterService, analyticsManager, auditor, taskParams);
@ -210,7 +247,23 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
verify(analyticsManager).isNodeShuttingDown();
verify(client, atLeastOnce()).settings();
verify(client, atLeastOnce()).threadPool();
if (nodeShuttingDown == false) {
// Verify progress was persisted
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(AnomalyDetectorsIndex.jobStateIndexWriteAlias()));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-job-id-progress"));
try (XContentParser parser = JsonXContent.jsonXContent.createParser(
NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, indexRequest.source().utf8ToString())) {
StoredProgress parsedProgress = StoredProgress.PARSER.apply(parser, null);
assertThat(parsedProgress.get(), equalTo(progress));
}
verify(client).execute(
same(UpdatePersistentTaskStatusAction.INSTANCE),
eq(new UpdatePersistentTaskStatusAction.Request(