[7.x] [ML] Update DFA progress document in the index the document belongs to (#51111) (#51117)

This commit is contained in:
Przemysław Witek 2020-01-17 08:12:54 +01:00 committed by GitHub
parent 13343b15c9
commit b1a526d5e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 159 additions and 23 deletions

View File

@ -19,8 +19,7 @@ import java.util.stream.IntStream;
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {
public static Response randomResponse() {
int listSize = randomInt(10);
List<Response.Stats> analytics = new ArrayList<>(listSize);
for (int j = 0; j < listSize; j++) {
@ -36,6 +35,11 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
}
@Override
protected Response createTestInstance() {
return randomResponse();
}
@Override
protected Writeable.Reader<Response> instanceReader() {
return Response::new;

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.dataframe;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
@ -15,6 +16,10 @@ import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRespo
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
@ -22,8 +27,10 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.index.reindex.BulkByScrollTask;
import org.elasticsearch.persistent.AllocatedPersistentTask;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskResult;
import org.elasticsearch.xpack.core.ml.MlTasks;
@ -239,35 +246,70 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private void persistProgress(Runnable runnable) {
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap(
statsResponse -> {
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias());
indexRequest.id(StoredProgress.documentId(taskParams.getId()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
String progressDocId = StoredProgress.documentId(taskParams.getId());
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
// Step 4: Run the runnable provided as the argument
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
runnable.run();
},
indexError -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
runnable.run();
}
);
// Step 3: Create or update the progress document:
// - if the document did not exist, create the new one in the current write index
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
searchResponse -> {
String indexOrAlias = AnomalyDetectorsIndex.jobStateIndexWriteAlias();
if (searchResponse.getHits().getHits().length > 0) {
indexOrAlias = searchResponse.getHits().getHits()[0].getIndex();
}
IndexRequest indexRequest = new IndexRequest(indexOrAlias)
.id(progressDocId)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
indexRequest.source(jsonBuilder);
}
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(
indexResponse -> {
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
runnable.run();
},
indexError -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
runnable.run();
}
));
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
"[{}] cannot persist progress as an error occurred while retrieving former progress document", taskParams.getId()), e);
runnable.run();
}
));
);
// Step 2: Search for existing progress document in .ml-state*
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
statsResponse -> {
stats.set(statsResponse.getResponse().results().get(0));
SearchRequest searchRequest =
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
.source(
new SearchSourceBuilder()
.size(1)
.query(new IdsQueryBuilder().addIds(progressDocId)));
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
},
e -> {
LOGGER.error(new ParameterizedMessage(
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
runnable.run();
}
);
// Step 1: Fetch progress to be persisted
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
}
/**

View File

@ -5,15 +5,42 @@
*/
package org.elasticsearch.xpack.ml.dataframe;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.TaskParams;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.mockito.ArgumentCaptor;
import org.mockito.InOrder;
import org.mockito.stubbing.Answer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class DataFrameAnalyticsTaskTests extends ESTestCase {
@ -87,4 +114,67 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", Collections.emptyList());
assertThat(startingState, equalTo(StartingState.FINISHED));
}
private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAlias) {
Client client = mock(Client.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(threadPool);
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse();
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
IndexResponse indexResponse = mock(IndexResponse.class);
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
TaskParams taskParams = new TaskParams("task_id", Version.CURRENT, Collections.emptyList(), false);
DataFrameAnalyticsTask task =
new DataFrameAnalyticsTask(
0,
"",
"",
null,
null,
client,
mock(ClusterService.class),
mock(DataFrameAnalyticsManager.class),
mock(DataFrameAnalyticsAuditor.class),
taskParams);
task.markAsCompleted();
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
InOrder inOrder = inOrder(client);
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
inOrder.verifyNoMoreInteractions();
IndexRequest indexRequest = indexRequestCaptor.getValue();
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
}
public void testMarkAsCompleted_ProgressDocumentCreated() {
testMarkAsCompleted(SearchHits.empty(), ".ml-state-write");
}
public void testMarkAsCompleted_ProgressDocumentUpdated() {
testMarkAsCompleted(
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
".ml-state-dummy");
}
@SuppressWarnings("unchecked")
private static <Response> Answer<Response> withResponse(Response response) {
return invocationOnMock -> {
ActionListener<Response> listener = (ActionListener<Response>) invocationOnMock.getArguments()[2];
listener.onResponse(response);
return null;
};
}
}