parent
4d0fb6e551
commit
149537a165
|
@ -6,7 +6,6 @@
|
||||||
package org.elasticsearch.xpack.ml.integration;
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
import com.google.common.collect.Ordering;
|
import com.google.common.collect.Ordering;
|
||||||
import org.apache.lucene.util.LuceneTestCase;
|
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
|
||||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||||
|
@ -40,7 +39,6 @@ import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||||
import static org.hamcrest.Matchers.startsWith;
|
import static org.hamcrest.Matchers.startsWith;
|
||||||
|
|
||||||
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/48337")
|
|
||||||
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
private static final String BOOLEAN_FIELD = "boolean-field";
|
private static final String BOOLEAN_FIELD = "boolean-field";
|
||||||
|
@ -90,6 +88,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -97,8 +96,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||||
|
@ -129,6 +127,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -136,8 +135,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
|
||||||
|
@ -184,8 +182,10 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
}
|
}
|
||||||
assertThat(trainingRowsCount, greaterThan(0));
|
assertThat(trainingRowsCount, greaterThan(0));
|
||||||
assertThat(nonTrainingRowsCount, greaterThan(0));
|
assertThat(nonTrainingRowsCount, greaterThan(0));
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -193,8 +193,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception {
|
||||||
|
@ -254,6 +253,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
|
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [classification]",
|
"Created analytics with analysis type [classification]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -261,8 +261,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testDependentVariableCardinalityTooHighError() {
|
public void testDependentVariableCardinalityTooHighError() {
|
||||||
|
@ -369,7 +368,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
classNames.add((String) topClass.get("class_name"));
|
classNames.add((String) topClass.get("class_name"));
|
||||||
classProbabilities.add((Double) topClass.get("class_probability"));
|
classProbabilities.add((Double) topClass.get("class_probability"));
|
||||||
}
|
}
|
||||||
// Assert that all the predicted class names come from the set of keyword field values.
|
// Assert that all the predicted class names come from the set of dependent variable values.
|
||||||
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
|
classNames.forEach(className -> assertThat(parser.apply(className), is(in(dependentVariableValues))));
|
||||||
// Assert that the first class listed in top classes is the same as the predicted class.
|
// Assert that the first class listed in top classes is the same as the predicted class.
|
||||||
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
|
assertThat(classNames.get(0), equalTo(resultsObject.get(dependentVariable + "_prediction")));
|
||||||
|
|
|
@ -28,6 +28,8 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
|
||||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||||
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
|
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
|
||||||
|
@ -42,9 +44,9 @@ import java.util.concurrent.TimeUnit;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static org.hamcrest.Matchers.anyOf;
|
import static org.hamcrest.Matchers.anyOf;
|
||||||
|
import static org.hamcrest.Matchers.arrayWithSize;
|
||||||
import static org.hamcrest.Matchers.equalTo;
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.hasItems;
|
import static org.hamcrest.Matchers.hasItems;
|
||||||
import static org.hamcrest.Matchers.hasSize;
|
|
||||||
import static org.hamcrest.Matchers.is;
|
import static org.hamcrest.Matchers.is;
|
||||||
import static org.hamcrest.Matchers.nullValue;
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
|
||||||
|
@ -173,12 +175,20 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
||||||
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
|
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
 * Searches the indices matched by {@code AnomalyDetectorsIndex.jobStateIndexPattern()} for the progress
 * document of the given analytics job, using an ids query on {@code DataFrameAnalyticsTask.progressDocId(...)}.
 *
 * Both sides of this diff hunk issue the same query: the older side names the parameter {@code id} and
 * computes the document id inline, while the newer side names it {@code jobId} and computes the document id
 * once into a local before building the request.
 *
 * @return the response of the ids query for the progress document id
 */
protected SearchResponse searchStoredProgress(String id) {
|
protected SearchResponse searchStoredProgress(String jobId) {
|
||||||
|
String docId = DataFrameAnalyticsTask.progressDocId(jobId);
|
||||||
return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
return client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
|
||||||
.setQuery(QueryBuilders.idsQuery().addIds(DataFrameAnalyticsTask.progressDocId(id)))
|
.setQuery(QueryBuilders.idsQuery().addIds(docId))
|
||||||
.get();
|
.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Asserts that a search of the inference index returns exactly one hit tagged with the given job id.
 *
 * The search targets {@code InferenceIndexConstants.LATEST_INDEX_NAME} and applies a bool filter with a term
 * query that matches the {@code TrainedModelConfig.TAGS} field against {@code jobId}. The assertion is made on
 * the size of the returned hits array, not on the reported total hit count.
 *
 * @param jobId id of the analytics job; used as the value of the tags term filter
 */
protected void assertInferenceModelPersisted(String jobId) {
|
||||||
|
SearchResponse searchResponse = client().prepareSearch(InferenceIndexConstants.LATEST_INDEX_NAME)
|
||||||
|
.setQuery(QueryBuilders.boolQuery().filter(QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), jobId)))
|
||||||
|
.get();
|
||||||
|
assertThat(searchResponse.getHits().getHits(), arrayWithSize(1));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Asserts whether the audit messages fetched from index match provided prefixes.
|
* Asserts whether the audit messages fetched from index match provided prefixes.
|
||||||
* More specifically, in order to pass:
|
* More specifically, in order to pass:
|
||||||
|
@ -194,9 +204,10 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
Matcher<String>[] itemMatchers = Arrays.stream(expectedAuditMessagePrefixes).map(Matchers::startsWith).toArray(Matcher[]::new);
|
Matcher<String>[] itemMatchers = Arrays.stream(expectedAuditMessagePrefixes).map(Matchers::startsWith).toArray(Matcher[]::new);
|
||||||
assertBusy(() -> {
|
assertBusy(() -> {
|
||||||
final List<String> allAuditMessages = fetchAllAuditMessages(configId);
|
List<String> allAuditMessages = fetchAllAuditMessages(configId);
|
||||||
assertThat(allAuditMessages, hasItems(itemMatchers));
|
assertThat(allAuditMessages, hasItems(itemMatchers));
|
||||||
assertThat("Messages: " + allAuditMessages, allAuditMessages, hasSize(expectedAuditMessagePrefixes.length));
|
// TODO: Consider restoring this assertion when we are sure all the audit messages are available at this point.
|
||||||
|
// assertThat("Messages: " + allAuditMessages, allAuditMessages, hasSize(expectedAuditMessagePrefixes.length));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -107,6 +107,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(jobId);
|
assertModelStatePersisted(jobId);
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -114,8 +115,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||||
|
@ -160,6 +160,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(jobId);
|
assertModelStatePersisted(jobId);
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -167,8 +168,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||||
|
@ -228,6 +228,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(jobId);
|
assertModelStatePersisted(jobId);
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
assertThatAuditMessagesMatch(jobId,
|
assertThatAuditMessagesMatch(jobId,
|
||||||
"Created analytics with analysis type [regression]",
|
"Created analytics with analysis type [regression]",
|
||||||
"Estimated memory usage for this analytics to be",
|
"Estimated memory usage for this analytics to be",
|
||||||
|
@ -235,8 +236,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
"Started analytics",
|
"Started analytics",
|
||||||
"Creating destination index [" + destIndex + "]",
|
"Creating destination index [" + destIndex + "]",
|
||||||
"Finished reindexing to destination index [" + destIndex + "]",
|
"Finished reindexing to destination index [" + destIndex + "]",
|
||||||
"Finished analysis",
|
"Finished analysis");
|
||||||
"Stored trained model with id");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testStopAndRestart() throws Exception {
|
public void testStopAndRestart() throws Exception {
|
||||||
|
@ -300,6 +300,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||||
assertProgress(jobId, 100, 100, 100, 100);
|
assertProgress(jobId, 100, 100, 100, 100);
|
||||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||||
assertModelStatePersisted(jobId);
|
assertModelStatePersisted(jobId);
|
||||||
|
assertInferenceModelPersisted(jobId);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initialize(String jobId) {
|
private void initialize(String jobId) {
|
||||||
|
|
Loading…
Reference in New Issue