parent
35453e2b0e
commit
4116452d90
|
@ -19,12 +19,14 @@ import org.elasticsearch.action.index.IndexAction;
|
|||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
|
@ -44,6 +46,7 @@ import java.util.Set;
|
|||
import static java.util.stream.Collectors.toList;
|
||||
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||
|
@ -243,6 +246,64 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "boolean");
|
||||
}
|
||||
|
||||
public void testStopAndRestart() throws Exception {
|
||||
initialize("classification_stop_and_restart");
|
||||
String predictedClassField = KEYWORD_FIELD + "_prediction";
|
||||
indexData(sourceIndex, 350, 0, KEYWORD_FIELD);
|
||||
|
||||
DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertIsStopped(jobId);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
|
||||
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
|
||||
assertBusy(() -> {
|
||||
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
|
||||
assertThat(
|
||||
state,
|
||||
is(anyOf(
|
||||
equalTo(DataFrameAnalyticsState.REINDEXING),
|
||||
equalTo(DataFrameAnalyticsState.ANALYZING),
|
||||
equalTo(DataFrameAnalyticsState.STOPPED))));
|
||||
});
|
||||
stopAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
// Now let's start it again
|
||||
try {
|
||||
startAnalytics(jobId);
|
||||
} catch (Exception e) {
|
||||
if (e.getMessage().equals("Cannot start because the job has already finished")) {
|
||||
// That means the job had managed to complete
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getFieldValue(destDoc, "ml");
|
||||
assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES)));
|
||||
assertThat(getFieldValue(resultsObject, "is_training"), is(true));
|
||||
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertMlResultsFieldMappings(predictedClassField, "keyword");
|
||||
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
|
||||
|
||||
}
|
||||
|
||||
public void testDependentVariableCardinalityTooHighError() throws Exception {
|
||||
initialize("cardinality_too_high");
|
||||
indexData(sourceIndex, 6, 5, KEYWORD_FIELD);
|
||||
|
|
Loading…
Reference in New Issue