diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java index 417ce0a83ff..9ef0f773ee8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java @@ -179,8 +179,8 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction; assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT)); - assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1))); - assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1))); + assertThat(trainingDocsCount, greaterThanOrEqualTo((long) (expectedTotalTrainingCount - 2))); + assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount) + 2)); for (String classValue : classCardinalities.keySet()) { double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction; @@ -221,7 +221,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { // should be close to the training percent, which is set to 0.5 for (int rowTrainingCount : trainingCountPerRow) { double meanCount = rowTrainingCount / (double) runCount; - assertThat(meanCount, is(closeTo(0.5, 0.1))); + assertThat(meanCount, is(closeTo(0.5, 0.12))); } }