This commit is contained in:
parent
13343b15c9
commit
b1a526d5e9
|
@ -19,8 +19,7 @@ import java.util.stream.IntStream;
|
||||||
|
|
||||||
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||||
|
|
||||||
@Override
|
public static Response randomResponse() {
|
||||||
protected Response createTestInstance() {
|
|
||||||
int listSize = randomInt(10);
|
int listSize = randomInt(10);
|
||||||
List<Response.Stats> analytics = new ArrayList<>(listSize);
|
List<Response.Stats> analytics = new ArrayList<>(listSize);
|
||||||
for (int j = 0; j < listSize; j++) {
|
for (int j = 0; j < listSize; j++) {
|
||||||
|
@ -36,6 +35,11 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
|
||||||
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
|
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Response createTestInstance() {
|
||||||
|
return randomResponse();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Writeable.Reader<Response> instanceReader() {
|
protected Writeable.Reader<Response> instanceReader() {
|
||||||
return Response::new;
|
return Response::new;
|
||||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.dataframe;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
|
import org.apache.lucene.util.SetOnce;
|
||||||
import org.elasticsearch.ResourceNotFoundException;
|
import org.elasticsearch.ResourceNotFoundException;
|
||||||
import org.elasticsearch.action.ActionListener;
|
import org.elasticsearch.action.ActionListener;
|
||||||
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
|
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
|
||||||
|
@ -15,6 +16,10 @@ import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRespo
|
||||||
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
|
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
|
||||||
import org.elasticsearch.action.index.IndexAction;
|
import org.elasticsearch.action.index.IndexAction;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
|
import org.elasticsearch.action.index.IndexResponse;
|
||||||
|
import org.elasticsearch.action.search.SearchAction;
|
||||||
|
import org.elasticsearch.action.search.SearchRequest;
|
||||||
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.client.Client;
|
import org.elasticsearch.client.Client;
|
||||||
import org.elasticsearch.cluster.service.ClusterService;
|
import org.elasticsearch.cluster.service.ClusterService;
|
||||||
|
@ -22,8 +27,10 @@ import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.unit.TimeValue;
|
import org.elasticsearch.common.unit.TimeValue;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||||
|
import org.elasticsearch.index.query.IdsQueryBuilder;
|
||||||
import org.elasticsearch.index.reindex.BulkByScrollTask;
|
import org.elasticsearch.index.reindex.BulkByScrollTask;
|
||||||
import org.elasticsearch.persistent.AllocatedPersistentTask;
|
import org.elasticsearch.persistent.AllocatedPersistentTask;
|
||||||
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||||
import org.elasticsearch.tasks.TaskId;
|
import org.elasticsearch.tasks.TaskId;
|
||||||
import org.elasticsearch.tasks.TaskResult;
|
import org.elasticsearch.tasks.TaskResult;
|
||||||
import org.elasticsearch.xpack.core.ml.MlTasks;
|
import org.elasticsearch.xpack.core.ml.MlTasks;
|
||||||
|
@ -239,35 +246,70 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
|
||||||
|
|
||||||
private void persistProgress(Runnable runnable) {
|
private void persistProgress(Runnable runnable) {
|
||||||
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
|
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
|
||||||
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
|
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap(
|
String progressDocId = StoredProgress.documentId(taskParams.getId());
|
||||||
statsResponse -> {
|
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
|
||||||
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
|
|
||||||
IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias());
|
// Step 4: Run the runnable provided as the argument
|
||||||
indexRequest.id(StoredProgress.documentId(taskParams.getId()));
|
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
|
||||||
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
indexResponse -> {
|
||||||
|
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
|
||||||
|
runnable.run();
|
||||||
|
},
|
||||||
|
indexError -> {
|
||||||
|
LOGGER.error(new ParameterizedMessage(
|
||||||
|
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
|
||||||
|
runnable.run();
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// Step 3: Create or update the progress document:
|
||||||
|
// - if the document did not exist, create the new one in the current write index
|
||||||
|
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
|
||||||
|
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
|
||||||
|
searchResponse -> {
|
||||||
|
String indexOrAlias = AnomalyDetectorsIndex.jobStateIndexWriteAlias();
|
||||||
|
if (searchResponse.getHits().getHits().length > 0) {
|
||||||
|
indexOrAlias = searchResponse.getHits().getHits()[0].getIndex();
|
||||||
|
}
|
||||||
|
IndexRequest indexRequest = new IndexRequest(indexOrAlias)
|
||||||
|
.id(progressDocId)
|
||||||
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||||
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
|
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
|
||||||
new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
|
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
|
||||||
indexRequest.source(jsonBuilder);
|
indexRequest.source(jsonBuilder);
|
||||||
}
|
}
|
||||||
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(
|
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
|
||||||
indexResponse -> {
|
|
||||||
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
|
|
||||||
runnable.run();
|
|
||||||
},
|
|
||||||
indexError -> {
|
|
||||||
LOGGER.error(new ParameterizedMessage(
|
|
||||||
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
|
|
||||||
runnable.run();
|
|
||||||
}
|
|
||||||
));
|
|
||||||
},
|
},
|
||||||
e -> {
|
e -> {
|
||||||
LOGGER.error(new ParameterizedMessage(
|
LOGGER.error(new ParameterizedMessage(
|
||||||
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
|
"[{}] cannot persist progress as an error occurred while retrieving former progress document", taskParams.getId()), e);
|
||||||
runnable.run();
|
runnable.run();
|
||||||
}
|
}
|
||||||
));
|
);
|
||||||
|
|
||||||
|
// Step 2: Search for existing progress document in .ml-state*
|
||||||
|
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
|
||||||
|
statsResponse -> {
|
||||||
|
stats.set(statsResponse.getResponse().results().get(0));
|
||||||
|
SearchRequest searchRequest =
|
||||||
|
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||||
|
.source(
|
||||||
|
new SearchSourceBuilder()
|
||||||
|
.size(1)
|
||||||
|
.query(new IdsQueryBuilder().addIds(progressDocId)));
|
||||||
|
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
|
||||||
|
},
|
||||||
|
e -> {
|
||||||
|
LOGGER.error(new ParameterizedMessage(
|
||||||
|
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
|
||||||
|
runnable.run();
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// Step 1: Fetch progress to be persisted
|
||||||
|
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
|
||||||
|
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -5,15 +5,42 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.dataframe;
|
package org.elasticsearch.xpack.ml.dataframe;
|
||||||
|
|
||||||
|
import org.elasticsearch.Version;
|
||||||
|
import org.elasticsearch.action.ActionListener;
|
||||||
|
import org.elasticsearch.action.index.IndexAction;
|
||||||
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
|
import org.elasticsearch.action.index.IndexResponse;
|
||||||
|
import org.elasticsearch.action.search.SearchAction;
|
||||||
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.cluster.service.ClusterService;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.search.SearchHit;
|
||||||
|
import org.elasticsearch.search.SearchHits;
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
|
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
|
||||||
|
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.TaskParams;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
|
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
|
||||||
|
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
|
import org.mockito.InOrder;
|
||||||
|
import org.mockito.stubbing.Answer;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.mockito.Matchers.any;
|
||||||
|
import static org.mockito.Matchers.eq;
|
||||||
|
import static org.mockito.Mockito.doAnswer;
|
||||||
|
import static org.mockito.Mockito.inOrder;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
||||||
|
|
||||||
|
@ -87,4 +114,67 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
||||||
StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", Collections.emptyList());
|
StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", Collections.emptyList());
|
||||||
assertThat(startingState, equalTo(StartingState.FINISHED));
|
assertThat(startingState, equalTo(StartingState.FINISHED));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAlias) {
|
||||||
|
Client client = mock(Client.class);
|
||||||
|
ThreadPool threadPool = mock(ThreadPool.class);
|
||||||
|
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
||||||
|
when(client.threadPool()).thenReturn(threadPool);
|
||||||
|
|
||||||
|
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse();
|
||||||
|
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
|
||||||
|
|
||||||
|
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||||
|
when(searchResponse.getHits()).thenReturn(searchHits);
|
||||||
|
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
|
||||||
|
|
||||||
|
IndexResponse indexResponse = mock(IndexResponse.class);
|
||||||
|
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
|
||||||
|
|
||||||
|
TaskParams taskParams = new TaskParams("task_id", Version.CURRENT, Collections.emptyList(), false);
|
||||||
|
DataFrameAnalyticsTask task =
|
||||||
|
new DataFrameAnalyticsTask(
|
||||||
|
0,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
client,
|
||||||
|
mock(ClusterService.class),
|
||||||
|
mock(DataFrameAnalyticsManager.class),
|
||||||
|
mock(DataFrameAnalyticsAuditor.class),
|
||||||
|
taskParams);
|
||||||
|
task.markAsCompleted();
|
||||||
|
|
||||||
|
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
|
||||||
|
|
||||||
|
InOrder inOrder = inOrder(client);
|
||||||
|
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
|
||||||
|
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
|
||||||
|
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
|
||||||
|
inOrder.verifyNoMoreInteractions();
|
||||||
|
|
||||||
|
IndexRequest indexRequest = indexRequestCaptor.getValue();
|
||||||
|
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
|
||||||
|
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMarkAsCompleted_ProgressDocumentCreated() {
|
||||||
|
testMarkAsCompleted(SearchHits.empty(), ".ml-state-write");
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMarkAsCompleted_ProgressDocumentUpdated() {
|
||||||
|
testMarkAsCompleted(
|
||||||
|
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
|
||||||
|
".ml-state-dummy");
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private static <Response> Answer<Response> withResponse(Response response) {
|
||||||
|
return invocationOnMock -> {
|
||||||
|
ActionListener<Response> listener = (ActionListener<Response>) invocationOnMock.getArguments()[2];
|
||||||
|
listener.onResponse(response);
|
||||||
|
return null;
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue