[7.x][ML] Improve assertion on regression alias field test (#60221) (#60264)

Previously the test was asserting the prediction on each document
was close 10.0 from the expected. It turned out that was not enough
as we occasionally saw the test failing by little.

Instead of relaxing that assertion, this commit changes it to
assert the mean prediction error is less than 10.0. This should
reduce the chances of the test failing significantly.

Fixes #60212

Backport of #60221
This commit is contained in:
Dimitris Athanasiou 2020-07-28 11:48:00 +03:00 committed by GitHub
parent fac5953d13
commit 981e436d6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 6 deletions

View File

@ -38,13 +38,13 @@ import java.util.Set;
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.emptyString; import static org.hamcrest.Matchers.emptyString;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.not;
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@ -461,7 +461,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
"Finished analysis"); "Finished analysis");
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/60212")
public void testAliasFields() throws Exception { public void testAliasFields() throws Exception {
// The goal of this test is to assert alias fields are included in the analytics job. // The goal of this test is to assert alias fields are included in the analytics job.
// We have a simple dataset with two integer fields: field_1 and field_2. // We have a simple dataset with two integer fields: field_1 and field_2.
@ -528,19 +527,26 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
startAnalytics(jobId); startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId); waitUntilAnalyticsIsStopped(jobId);
double predictionErrorSum = 0.0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).setSize(totalDocCount).get(); SearchResponse sourceData = client().prepareSearch(sourceIndex).setSize(totalDocCount).get();
for (SearchHit hit : sourceData.getHits()) { for (SearchHit hit : sourceData.getHits()) {
Map<String, Object> destDoc = getDestDoc(config, hit); Map<String, Object> destDoc = getDestDoc(config, hit);
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc); Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
int featureValue = (int) destDoc.get("field_1");
double predictionValue = (double) resultsObject.get(predictionField);
assertThat(predictionValue, closeTo(2 * featureValue, 10.0));
assertThat(resultsObject.containsKey(predictionField), is(true)); assertThat(resultsObject.containsKey(predictionField), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true));
int featureValue = (int) destDoc.get("field_1");
double predictionValue = (double) resultsObject.get(predictionField);
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
} }
// We assert on the mean prediction error in order to reduce the probability
// the test fails compared to asserting on the prediction of each individual doc.
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
assertThat(meanPredictionError, lessThanOrEqualTo(10.0));
assertProgressComplete(jobId); assertProgressComplete(jobId);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId()); assertModelStatePersisted(stateDocId());