mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-23 05:15:04 +00:00
This commit adds state persist/restore for data frame analytics classification jobs. Backport of #50040
This commit is contained in:
parent
1c3ce110bd
commit
e6cbcf7f7c
@ -253,12 +253,12 @@ public class Classification implements DataFrameAnalysis {
|
||||
|
||||
@Override
|
||||
public boolean persistsState() {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStateDocId(String jobId) {
|
||||
throw new UnsupportedOperationException();
|
||||
return jobId + "_classification_state#1";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -208,4 +208,11 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
|
||||
assertThat(json, containsString("randomize_seed"));
|
||||
}
|
||||
}
|
||||
|
||||
public void testGetStateDocId() {
|
||||
Classification classification = createRandom();
|
||||
assertThat(classification.persistsState(), is(true));
|
||||
String randomId = randomAlphaOfLength(10);
|
||||
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
|
||||
}
|
||||
}
|
||||
|
@ -95,6 +95,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
@ -135,6 +136,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
@ -195,6 +197,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [classification]",
|
||||
@ -447,4 +450,8 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
|
||||
}
|
||||
}
|
||||
|
||||
protected String stateDocId() {
|
||||
return jobId + "_classification_state#1";
|
||||
}
|
||||
}
|
||||
|
@ -274,4 +274,11 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
||||
assertThat(trainingRowsIds.isEmpty(), is(false));
|
||||
return trainingRowsIds;
|
||||
}
|
||||
|
||||
protected static void assertModelStatePersisted(String stateDocId) {
|
||||
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||
.setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
|
||||
.get();
|
||||
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
|
||||
}
|
||||
}
|
||||
|
@ -12,14 +12,12 @@ import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.common.unit.TimeValue;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.Arrays;
|
||||
@ -82,7 +80,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
@ -119,7 +117,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
@ -171,7 +169,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
assertThatAuditMessagesMatch(jobId,
|
||||
"Created analytics with analysis type [regression]",
|
||||
@ -233,7 +231,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(jobId);
|
||||
assertModelStatePersisted(stateDocId());
|
||||
assertInferenceModelPersisted(jobId);
|
||||
}
|
||||
|
||||
@ -324,11 +322,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
return resultsObject;
|
||||
}
|
||||
|
||||
private static void assertModelStatePersisted(String jobId) {
|
||||
String docId = jobId + "_regression_state#1";
|
||||
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||
.setQuery(QueryBuilders.idsQuery().addIds(docId))
|
||||
.get();
|
||||
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
|
||||
protected String stateDocId() {
|
||||
return jobId + "_regression_state#1";
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user