[7.x][ML] Decouple DFA progress testing from analyses phases (#55925) (#56024)

This refactors the native integration tests to assert progress without
expecting an explicit, fixed set of analysis phases. The exact phases
can be tested with YAML tests in a single place.

Backport of #55925
This commit is contained in:
Dimitris Athanasiou 2020-04-30 17:05:47 +03:00 committed by GitHub
parent 273ff6a105
commit 17b904def5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 58 additions and 48 deletions

View File

@ -103,7 +103,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -121,7 +121,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(importanceArray, hasSize(greaterThan(0)));
}
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -150,7 +150,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -171,7 +171,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -210,7 +210,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -245,7 +245,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L));
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -305,7 +305,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
NodeAcknowledgedResponse response = startAnalytics(jobId);
assertThat(response.getNode(), not(emptyString()));
@ -346,7 +346,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
}
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -394,7 +394,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
}
public void testDependentVariableIsNested() throws Exception {
@ -407,7 +407,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -425,7 +425,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -443,7 +443,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -539,7 +539,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
});
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
}
public void testSetUpgradeMode_NewTaskDoesNotStart() throws Exception {
@ -572,7 +572,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);

View File

@ -67,6 +67,7 @@ import static org.elasticsearch.common.xcontent.support.XContentMapValues.extrac
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
@ -199,19 +200,28 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
assertThat("Stats were: " + Strings.toString(stats), stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED));
}
protected void assertProgress(String id, int reindexing, int loadingData, int analyzing, int writingResults) {
/**
 * Asserts that every phase reported for the given analytics job is at 0 percent.
 *
 * @param id the id of the data frame analytics job whose progress is checked
 */
protected void assertProgressIsZero(String id) {
    List<PhaseProgress> phases = getProgress(id);
    // No phase may report a non-zero percentage.
    boolean everyPhaseAtZero = phases.stream().noneMatch(phase -> phase.getProgressPercent() != 0);
    assertThat("progress is not all zero: " + phases, everyPhaseAtZero, is(true));
}
/**
 * Asserts that every phase reported for the given analytics job is at 100 percent.
 *
 * @param id the id of the data frame analytics job whose progress is checked
 */
protected void assertProgressComplete(String id) {
    List<PhaseProgress> progress = getProgress(id);
    // The reason string is only shown when the assertion fails, so it must describe the failure.
    assertThat("progress is not complete: " + progress,
        progress.stream().allMatch(phaseProgress -> phaseProgress.getProgressPercent() == 100), is(true));
}
private List<PhaseProgress> getProgress(String id) {
GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id);
assertThat(stats.getId(), equalTo(id));
List<PhaseProgress> progress = stats.getProgress();
assertThat(progress, hasSize(4));
// We should have at least 4 phases: reindexing, loading_data, writing_results, plus at least one for the analysis
assertThat(progress.size(), greaterThanOrEqualTo(4));
assertThat(progress.get(0).getPhase(), equalTo("reindexing"));
assertThat(progress.get(1).getPhase(), equalTo("loading_data"));
assertThat(progress.get(2).getPhase(), equalTo("analyzing"));
assertThat(progress.get(3).getPhase(), equalTo("writing_results"));
assertThat(progress.get(0).getProgressPercent(), equalTo(reindexing));
assertThat(progress.get(1).getProgressPercent(), equalTo(loadingData));
assertThat(progress.get(2).getProgressPercent(), equalTo(analyzing));
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
assertThat(progress.get(progress.size() - 1).getPhase(), equalTo("writing_results"));
return progress;
}
protected SearchResponse searchStoredProgress(String jobId) {

View File

@ -74,7 +74,7 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -108,7 +108,7 @@ public class OutlierDetectionWithMissingFieldsIT extends MlNativeDataFrameAnalyt
}
}
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
}

View File

@ -72,7 +72,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -101,7 +101,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
isPresent());
}
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -129,7 +129,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -143,7 +143,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(resultsObject.get("is_training"), is(true));
}
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
@ -184,7 +184,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@ -215,7 +215,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L));
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -243,7 +243,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
NodeAcknowledgedResponse response = startAnalytics(jobId);
assertThat(response.getNode(), not(emptyString()));
@ -284,7 +284,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
assertThat(resultsObject.get("is_training"), is(true));
}
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -342,7 +342,7 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
@ -380,11 +380,11 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
putAnalytics(config);
assertIsStopped(jobId);
assertProgress(jobId, 0, 0, 0, 0);
assertProgressIsZero(jobId);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
assertProgress(jobId, 100, 100, 100, 100);
assertProgressComplete(jobId);
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
}

View File

@ -101,7 +101,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -143,7 +143,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier)));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",
@ -186,7 +186,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -201,7 +201,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",
@ -260,7 +260,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -285,7 +285,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(1.0)));
}
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",
@ -397,7 +397,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -412,7 +412,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",
@ -458,7 +458,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -473,7 +473,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) bulkRequestBuilder.numberOfActions()));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",
@ -650,7 +650,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
@ -691,7 +691,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
putAnalytics(config);
assertIsStopped(id);
assertProgress(id, 0, 0, 0, 0);
assertProgressIsZero(id);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
@ -730,7 +730,7 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
}
assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier)));
assertProgress(id, 100, 100, 100, 100);
assertProgressComplete(id);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
assertThatAuditMessagesMatch(id,
"Created analytics with analysis type [outlier_detection]",