diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 6b2e2730b03..a0e0c0b4cca 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -19,12 +19,14 @@ import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; @@ -44,6 +46,7 @@ import java.util.Set; import static java.util.stream.Collectors.toList; import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -243,6 +246,64 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "boolean"); } + public void testStopAndRestart() throws Exception { + initialize("classification_stop_and_restart"); + String predictedClassField = KEYWORD_FIELD + "_prediction"; + indexData(sourceIndex, 350, 0, KEYWORD_FIELD); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + registerAnalytics(config); + putAnalytics(config); + + assertIsStopped(jobId); + assertProgress(jobId, 0, 0, 0, 0); + + startAnalytics(jobId); + + // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED. + assertBusy(() -> { + DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState(); + assertThat( + state, + is(anyOf( + equalTo(DataFrameAnalyticsState.REINDEXING), + equalTo(DataFrameAnalyticsState.ANALYZING), + equalTo(DataFrameAnalyticsState.STOPPED)))); + }); + stopAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + // Now let's start it again + try { + startAnalytics(jobId); + } catch (Exception e) { + if (e.getMessage().equals("Cannot start because the job has already finished")) { + // That means the job had managed to complete + } else { + throw e; + } + } + + waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1)); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); + for (SearchHit hit : sourceData.getHits()) { + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); + assertThat(getFieldValue(resultsObject, "is_training"), is(true)); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + } + + assertProgress(jobId, 100, 100, 100, 100); + assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); + assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); + + } + public void testDependentVariableCardinalityTooHighError() throws Exception { initialize("cardinality_too_high"); indexData(sourceIndex, 6, 5, KEYWORD_FIELD);