This commit is contained in:
parent
13343b15c9
commit
b1a526d5e9
|
@ -19,8 +19,7 @@ import java.util.stream.IntStream;
|
|||
|
||||
public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
|
||||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
public static Response randomResponse() {
|
||||
int listSize = randomInt(10);
|
||||
List<Response.Stats> analytics = new ArrayList<>(listSize);
|
||||
for (int j = 0; j < listSize; j++) {
|
||||
|
@ -36,6 +35,11 @@ public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireS
|
|||
return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Response createTestInstance() {
|
||||
return randomResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<Response> instanceReader() {
|
||||
return Response::new;
|
||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.dataframe;
|
|||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.apache.lucene.util.SetOnce;
|
||||
import org.elasticsearch.ResourceNotFoundException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
|
||||
|
@ -15,6 +16,10 @@ import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRespo
|
|||
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest;
|
||||
import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.index.IndexResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
|
@ -22,8 +27,10 @@ import org.elasticsearch.common.Nullable;
|
|||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.common.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.index.query.IdsQueryBuilder;
|
||||
import org.elasticsearch.index.reindex.BulkByScrollTask;
|
||||
import org.elasticsearch.persistent.AllocatedPersistentTask;
|
||||
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
||||
import org.elasticsearch.tasks.TaskId;
|
||||
import org.elasticsearch.tasks.TaskResult;
|
||||
import org.elasticsearch.xpack.core.ml.MlTasks;
|
||||
|
@ -239,35 +246,70 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
|
|||
|
||||
private void persistProgress(Runnable runnable) {
|
||||
LOGGER.debug("[{}] Persisting progress", taskParams.getId());
|
||||
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, ActionListener.wrap(
|
||||
statsResponse -> {
|
||||
GetDataFrameAnalyticsStatsAction.Response.Stats stats = statsResponse.getResponse().results().get(0);
|
||||
IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.jobStateIndexWriteAlias());
|
||||
indexRequest.id(StoredProgress.documentId(taskParams.getId()));
|
||||
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
String progressDocId = StoredProgress.documentId(taskParams.getId());
|
||||
SetOnce<GetDataFrameAnalyticsStatsAction.Response.Stats> stats = new SetOnce<>();
|
||||
|
||||
// Step 4: Run the runnable provided as the argument
|
||||
ActionListener<IndexResponse> indexProgressDocListener = ActionListener.wrap(
|
||||
indexResponse -> {
|
||||
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
|
||||
runnable.run();
|
||||
},
|
||||
indexError -> {
|
||||
LOGGER.error(new ParameterizedMessage(
|
||||
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
|
||||
runnable.run();
|
||||
}
|
||||
);
|
||||
|
||||
// Step 3: Create or update the progress document:
|
||||
// - if the document did not exist, create the new one in the current write index
|
||||
// - if the document did exist, update it in the index where it resides (not necessarily the current write index)
|
||||
ActionListener<SearchResponse> searchFormerProgressDocListener = ActionListener.wrap(
|
||||
searchResponse -> {
|
||||
String indexOrAlias = AnomalyDetectorsIndex.jobStateIndexWriteAlias();
|
||||
if (searchResponse.getHits().getHits().length > 0) {
|
||||
indexOrAlias = searchResponse.getHits().getHits()[0].getIndex();
|
||||
}
|
||||
IndexRequest indexRequest = new IndexRequest(indexOrAlias)
|
||||
.id(progressDocId)
|
||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
try (XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) {
|
||||
new StoredProgress(stats.getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
|
||||
new StoredProgress(stats.get().getProgress()).toXContent(jsonBuilder, Payload.XContent.EMPTY_PARAMS);
|
||||
indexRequest.source(jsonBuilder);
|
||||
}
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap(
|
||||
indexResponse -> {
|
||||
LOGGER.debug("[{}] Successfully indexed progress document", taskParams.getId());
|
||||
runnable.run();
|
||||
},
|
||||
indexError -> {
|
||||
LOGGER.error(new ParameterizedMessage(
|
||||
"[{}] cannot persist progress as an error occurred while indexing", taskParams.getId()), indexError);
|
||||
runnable.run();
|
||||
}
|
||||
));
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexProgressDocListener);
|
||||
},
|
||||
e -> {
|
||||
LOGGER.error(new ParameterizedMessage(
|
||||
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
|
||||
"[{}] cannot persist progress as an error occurred while retrieving former progress document", taskParams.getId()), e);
|
||||
runnable.run();
|
||||
}
|
||||
));
|
||||
);
|
||||
|
||||
// Step 2: Search for existing progress document in .ml-state*
|
||||
ActionListener<GetDataFrameAnalyticsStatsAction.Response> getStatsListener = ActionListener.wrap(
|
||||
statsResponse -> {
|
||||
stats.set(statsResponse.getResponse().results().get(0));
|
||||
SearchRequest searchRequest =
|
||||
new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||
.source(
|
||||
new SearchSourceBuilder()
|
||||
.size(1)
|
||||
.query(new IdsQueryBuilder().addIds(progressDocId)));
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, searchFormerProgressDocListener);
|
||||
},
|
||||
e -> {
|
||||
LOGGER.error(new ParameterizedMessage(
|
||||
"[{}] cannot persist progress as an error occurred while retrieving stats", taskParams.getId()), e);
|
||||
runnable.run();
|
||||
}
|
||||
);
|
||||
|
||||
// Step 1: Fetch progress to be persisted
|
||||
GetDataFrameAnalyticsStatsAction.Request getStatsRequest = new GetDataFrameAnalyticsStatsAction.Request(taskParams.getId());
|
||||
executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsStatsAction.INSTANCE, getStatsRequest, getStatsListener);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -5,15 +5,42 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe;
|
||||
|
||||
import org.elasticsearch.Version;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.index.IndexAction;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.index.IndexResponse;
|
||||
import org.elasticsearch.action.search.SearchAction;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.search.SearchHits;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsActionResponseTests;
|
||||
import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.TaskParams;
|
||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.StartingState;
|
||||
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.InOrder;
|
||||
import org.mockito.stubbing.Answer;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.eq;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.inOrder;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
||||
|
||||
|
@ -87,4 +114,67 @@ public class DataFrameAnalyticsTaskTests extends ESTestCase {
|
|||
StartingState startingState = DataFrameAnalyticsTask.determineStartingState("foo", Collections.emptyList());
|
||||
assertThat(startingState, equalTo(StartingState.FINISHED));
|
||||
}
|
||||
|
||||
private void testMarkAsCompleted(SearchHits searchHits, String expectedIndexOrAlias) {
|
||||
Client client = mock(Client.class);
|
||||
ThreadPool threadPool = mock(ThreadPool.class);
|
||||
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
||||
when(client.threadPool()).thenReturn(threadPool);
|
||||
|
||||
GetDataFrameAnalyticsStatsAction.Response getStatsResponse = GetDataFrameAnalyticsStatsActionResponseTests.randomResponse();
|
||||
doAnswer(withResponse(getStatsResponse)).when(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
|
||||
|
||||
SearchResponse searchResponse = mock(SearchResponse.class);
|
||||
when(searchResponse.getHits()).thenReturn(searchHits);
|
||||
doAnswer(withResponse(searchResponse)).when(client).execute(eq(SearchAction.INSTANCE), any(), any());
|
||||
|
||||
IndexResponse indexResponse = mock(IndexResponse.class);
|
||||
doAnswer(withResponse(indexResponse)).when(client).execute(eq(IndexAction.INSTANCE), any(), any());
|
||||
|
||||
TaskParams taskParams = new TaskParams("task_id", Version.CURRENT, Collections.emptyList(), false);
|
||||
DataFrameAnalyticsTask task =
|
||||
new DataFrameAnalyticsTask(
|
||||
0,
|
||||
"",
|
||||
"",
|
||||
null,
|
||||
null,
|
||||
client,
|
||||
mock(ClusterService.class),
|
||||
mock(DataFrameAnalyticsManager.class),
|
||||
mock(DataFrameAnalyticsAuditor.class),
|
||||
taskParams);
|
||||
task.markAsCompleted();
|
||||
|
||||
ArgumentCaptor<IndexRequest> indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class);
|
||||
|
||||
InOrder inOrder = inOrder(client);
|
||||
inOrder.verify(client).execute(eq(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
|
||||
inOrder.verify(client).execute(eq(SearchAction.INSTANCE), any(), any());
|
||||
inOrder.verify(client).execute(eq(IndexAction.INSTANCE), indexRequestCaptor.capture(), any());
|
||||
inOrder.verifyNoMoreInteractions();
|
||||
|
||||
IndexRequest indexRequest = indexRequestCaptor.getValue();
|
||||
assertThat(indexRequest.index(), equalTo(expectedIndexOrAlias));
|
||||
assertThat(indexRequest.id(), equalTo("data_frame_analytics-task_id-progress"));
|
||||
}
|
||||
|
||||
public void testMarkAsCompleted_ProgressDocumentCreated() {
|
||||
testMarkAsCompleted(SearchHits.empty(), ".ml-state-write");
|
||||
}
|
||||
|
||||
public void testMarkAsCompleted_ProgressDocumentUpdated() {
|
||||
testMarkAsCompleted(
|
||||
new SearchHits(new SearchHit[]{ SearchHit.createFromMap(Collections.singletonMap("_index", ".ml-state-dummy")) }, null, 0.0f),
|
||||
".ml-state-dummy");
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private static <Response> Answer<Response> withResponse(Response response) {
|
||||
return invocationOnMock -> {
|
||||
ActionListener<Response> listener = (ActionListener<Response>) invocationOnMock.getArguments()[2];
|
||||
listener.onResponse(response);
|
||||
return null;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue