Previously the test was asserting the prediction on each document was close 10.0 from the expected. It turned out that was not enough as we occasionally saw the test failing by little. Instead of relaxing that assertion, this commit changes it to assert the mean prediction error is less than 10.0. This should reduce the chances of the test failing significantly. Fixes #60212 Backport of #60221
This commit is contained in:
parent
fac5953d13
commit
981e436d6c
|
@ -38,13 +38,13 @@ import java.util.Set;
|
|||
|
||||
import static org.elasticsearch.test.hamcrest.OptionalMatchers.isPresent;
|
||||
import static org.hamcrest.Matchers.anyOf;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.emptyString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
@ -461,7 +461,6 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
"Finished analysis");
|
||||
}
|
||||
|
||||
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/60212")
|
||||
public void testAliasFields() throws Exception {
|
||||
// The goal of this test is to assert alias fields are included in the analytics job.
|
||||
// We have a simple dataset with two integer fields: field_1 and field_2.
|
||||
|
@ -528,19 +527,26 @@ public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
|||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
double predictionErrorSum = 0.0;
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setSize(totalDocCount).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
Map<String, Object> destDoc = getDestDoc(config, hit);
|
||||
Map<String, Object> resultsObject = getMlResultsObjectFromDestDoc(destDoc);
|
||||
|
||||
int featureValue = (int) destDoc.get("field_1");
|
||||
double predictionValue = (double) resultsObject.get(predictionField);
|
||||
assertThat(predictionValue, closeTo(2 * featureValue, 10.0));
|
||||
|
||||
assertThat(resultsObject.containsKey(predictionField), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
|
||||
int featureValue = (int) destDoc.get("field_1");
|
||||
double predictionValue = (double) resultsObject.get(predictionField);
|
||||
predictionErrorSum += Math.abs(predictionValue - 2 * featureValue);
|
||||
}
|
||||
|
||||
// We assert on the mean prediction error in order to reduce the probability
|
||||
// the test fails compared to asserting on the prediction of each individual doc.
|
||||
double meanPredictionError = predictionErrorSum / sourceData.getHits().getHits().length;
|
||||
assertThat(meanPredictionError, lessThanOrEqualTo(10.0));
|
||||
|
||||
assertProgressComplete(jobId);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
assertModelStatePersisted(stateDocId());
|
||||
|
|
Loading…
Reference in New Issue