[7.x][ML] Improve stability of stratified splitter tests (#58180) (#58224)

The main improvement here is that the total expected
count of training rows in the test is calculated as the
sum of the training fraction times the cardinality of each
class (instead of the training fraction times the total doc count).

Also relaxes slightly the error bound on the uniformity test from 0.12
to 0.13.

Closes #54122

Backport of #58180
This commit is contained in:
Dimitris Athanasiou 2020-06-17 12:40:21 +03:00 committed by GitHub
parent 85be78b624
commit 36dbf08d47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 6 deletions

View File

@ -59,11 +59,11 @@ public class StratifiedCrossValidationSplitter implements CrossValidationSplitte
// We ensure the target sample count is at least 1 as if the cardinality
// is too low we might get a target of zero and, thus, no samples of the whole class
double targetSampleCount = Math.max(1.0, samplingRatio * sample.cardinality);
long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.cardinality);
// The idea here is that the probability increases as the chances we have to get the target proportion
// for a class decreases.
double p = (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
double p = (double) (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
boolean isTraining = random.nextDouble() <= p;

View File

@ -177,10 +177,13 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
// We can assert we're plus/minus 1 from rounding error
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
long expectedTotalTrainingCount = 0;
for (long classCardinality : classCardinalities.values()) {
expectedTotalTrainingCount += trainingFraction * classCardinality;
}
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
assertThat(trainingDocsCount, greaterThanOrEqualTo((long) (expectedTotalTrainingCount - 2)));
assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount) + 2));
assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2));
assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount));
for (String classValue : classCardinalities.keySet()) {
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
@ -221,7 +224,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
// should be close to the training percent, which is set to 0.5
for (int rowTrainingCount : trainingCountPerRow) {
double meanCount = rowTrainingCount / (double) runCount;
assertThat(meanCount, is(closeTo(0.5, 0.12)));
assertThat(meanCount, is(closeTo(0.5, 0.13)));
}
}