diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index fc22edb5d1c..31f70e02547 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -38,13 +38,13 @@ import java.util.Set; import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent; import static org.hamcrest.Matchers.anyOf; -import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.emptyString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -461,7 +461,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { "Finished analysis"); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/60212") public void testAliasFields() throws Exception { // The goal of this test is to assert alias fields are included in the analytics job. // We have a simple dataset with two integer fields: field_1 and field_2. @@ -528,19 +527,26 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { startAnalytics(jobId); waitUntilAnalyticsIsStopped(jobId); + double predictionErrorSum = 0.0; + SearchResponse sourceData = client().prepareSearch(sourceIndex).setSize(totalDocCount).get(); for (SearchHit hit : sourceData.getHits()) { Map destDoc = getDestDoc(config, hit); Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); - int featureValue = (int) destDoc.get("field_1"); - double predictionValue = (double) resultsObject.get(predictionField); - assertThat(predictionValue, closeTo(2 * featureValue, 10.0)); - assertThat(resultsObject.containsKey(predictionField), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); + + int featureValue = (int) destDoc.get("field_1"); + double predictionValue = (double) resultsObject.get(predictionField); + predictionErrorSum += Math.abs(predictionValue - 2 * featureValue); } + // We assert on the mean prediction error in order to reduce the probability + // the test fails compared to asserting on the prediction of each individual doc. + double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length; + assertThat(meanPredictionError, lessThanOrEqualTo(10.0)); + assertProgressComplete(jobId); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId());