[7.x] [ML] Persist/restore state for DFA classification (#50040) (#50147)

This commit adds state persist/restore for data frame analytics classification jobs.

Backport of #50040
This commit is contained in:
Dimitris Athanasiou 2019-12-13 10:33:19 +02:00 committed by GitHub
parent 1c3ce110bd
commit e6cbcf7f7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 29 additions and 14 deletions

View File

@ -253,12 +253,12 @@ public class Classification implements DataFrameAnalysis {
@Override
public boolean persistsState() {
return false;
return true;
}
@Override
public String getStateDocId(String jobId) {
throw new UnsupportedOperationException();
return jobId + "_classification_state#1";
}
@Override

View File

@ -208,4 +208,11 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(json, containsString("randomize_seed"));
}
}
/**
 * Checks that a randomly created {@link Classification} reports that it persists state and
 * that its state document id is the given job id followed by {@code _classification_state#1}.
 */
public void testGetStateDocId() {
    Classification analysis = createRandom();
    assertThat(analysis.persistsState(), is(true));

    // Use a random job id so the check covers how the id is composed, not one fixed value.
    String jobId = randomAlphaOfLength(10);
    String expectedDocId = jobId + "_classification_state#1";
    assertThat(analysis.getStateDocId(jobId), equalTo(expectedDocId));
}
}

View File

@ -95,6 +95,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
@ -135,6 +136,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
@ -195,6 +197,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
@ -447,4 +450,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
}
/**
 * Builds the state document id by appending the classification state suffix to {@code jobId}.
 *
 * @return {@code jobId} followed by {@code _classification_state#1}
 */
protected String stateDocId() {
    final String stateSuffix = "_classification_state#1";
    return jobId + stateSuffix;
}
}

View File

@ -274,4 +274,11 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat(trainingRowsIds.isEmpty(), is(false));
return trainingRowsIds;
}
/**
 * Searches the indices matching {@code AnomalyDetectorsIndex.jobStateIndexPattern()} for a
 * document with the given id and asserts that the search returned exactly one hit.
 *
 * @param stateDocId id of the state document that is expected to exist
 */
protected static void assertModelStatePersisted(String stateDocId) {
    SearchResponse stateSearch = client()
        .prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
        .setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
        .get();

    int returnedHits = stateSearch.getHits().getHits().length;
    assertThat(returnedHits, equalTo(1));
}
}

View File

@ -12,14 +12,12 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.junit.After;
import java.util.Arrays;
@ -82,7 +80,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
@ -119,7 +117,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
@ -171,7 +169,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
@ -233,7 +231,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
}
@ -324,11 +322,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
return resultsObject;
}
private static void assertModelStatePersisted(String jobId) {
String docId = jobId + "_regression_state#1";
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(docId))
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
protected String stateDocId() {
return jobId + "_regression_state#1";
}
}