Assert that inference model has been persisted (#48332) (#48453)

This commit is contained in:
Przemysław Witek 2019-10-24 14:18:43 +02:00 committed by GitHub
parent 4d0fb6e551
commit 149537a165
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 22 deletions

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.ml.integration; package org.elasticsearch.xpack.ml.integration;
import com.google.common.collect.Ordering; import com.google.common.collect.Ordering;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder;
@ -40,7 +39,6 @@ import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.Matchers.startsWith;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/48337")
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
private static final String BOOLEAN_FIELD = "boolean-field"; private static final String BOOLEAN_FIELD = "boolean-field";
@ -90,6 +88,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]", "Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -97,8 +96,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
@ -129,6 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]", "Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -136,8 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
@ -184,8 +182,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
} }
assertThat(trainingRowsCount, greaterThan(0)); assertThat(trainingRowsCount, greaterThan(0));
assertThat(nonTrainingRowsCount, greaterThan(0)); assertThat(nonTrainingRowsCount, greaterThan(0));
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]", "Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -193,8 +193,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
@ -254,6 +253,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]", "Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -261,8 +261,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testDependentVariableCardinalityTooHighError() { public void testDependentVariableCardinalityTooHighError() {
@ -369,7 +368,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
classNames.add((String) topClass.get("class_name")); classNames.add((String) topClass.get("class_name"));
classProbabilities.add((Double) topClass.get("class_probability")); classProbabilities.add((Double) topClass.get("class_probability"));
} }
// Assert that all the predicted class names come from the set of keyword field values. // Assert that all the predicted class names come from the set of dependent variable values.
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues)))); classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
// Assert that the first class listed in top classes is the same as the predicted class. // Assert that the first class listed in top classes is the same as the predicted class.
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction"))); assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));

View File

@ -28,6 +28,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.notifications.AuditorField; import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@ -42,9 +44,9 @@ import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.nullValue;
@ -173,12 +175,20 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults)); assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
} }
protected SearchResponse searchStoredProgress(String id) { protected SearchResponse searchStoredProgress(String jobId) {
String docId = DataFrameAnalyticsTask.progressDocId(jobId);
return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(DataFrameAnalyticsTask.progressDocId(id))) .setQuery(QueryBuilders.idsQuery().addIds(docId))
.get(); .get();
} }
protected void assertInferenceModelPersisted(String jobId) {
SearchResponse searchResponse = client().prepareSearch(InferenceIndexConstants.LATEST_INDEX_NAME)
.setQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), jobId)))
.get();
assertThat(searchResponse.getHits().getHits(), arrayWithSize(1));
}
/** /**
* Asserts whether the audit messages fetched from index match provided prefixes. * Asserts whether the audit messages fetched from index match provided prefixes.
* More specifically, in order to pass: * More specifically, in order to pass:
@ -194,9 +204,10 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Matcher<String>[] itemMatchers = Arrays.stream(expectedAuditMessagePrefixes).map(Matchers::startsWith).toArray(Matcher[]::new); Matcher<String>[] itemMatchers = Arrays.stream(expectedAuditMessagePrefixes).map(Matchers::startsWith).toArray(Matcher[]::new);
assertBusy(() -> { assertBusy(() -> {
final List<String> allAuditMessages = fetchAllAuditMessages(configId); List<String> allAuditMessages = fetchAllAuditMessages(configId);
assertThat(allAuditMessages, hasItems(itemMatchers)); assertThat(allAuditMessages, hasItems(itemMatchers));
assertThat("Messages: " + allAuditMessages, allAuditMessages, hasSize(expectedAuditMessagePrefixes.length)); // TODO: Consider restoring this assertion when we are sure all the audit messages are available at this point.
// assertThat("Messages: " + allAuditMessages, allAuditMessages, hasSize(expectedAuditMessagePrefixes.length));
}); });
} }

View File

@ -107,6 +107,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId); assertModelStatePersisted(jobId);
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]", "Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -114,8 +115,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
@ -160,6 +160,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId); assertModelStatePersisted(jobId);
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]", "Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -167,8 +168,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception { public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
@ -228,6 +228,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId); assertModelStatePersisted(jobId);
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId, assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]", "Created analytics with analysis type [regression]",
"Estimated memory usage for this analytics to be", "Estimated memory usage for this analytics to be",
@ -235,8 +236,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Started analytics", "Started analytics",
"Creating destination index [" + destIndex + "]", "Creating destination index [" + destIndex + "]",
"Finished reindexing to destination index [" + destIndex + "]", "Finished reindexing to destination index [" + destIndex + "]",
"Finished analysis", "Finished analysis");
"Stored trained model with id");
} }
public void testStopAndRestart() throws Exception { public void testStopAndRestart() throws Exception {
@ -300,6 +300,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertProgress(jobId, 100, 100, 100, 100); assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId); assertModelStatePersisted(jobId);
assertInferenceModelPersisted(jobId);
} }
private void initialize(String jobId) { private void initialize(String jobId) {