The main improvement here is that the total expected count of training rows in the test is calculated as the sum of the training fraction times the cardinality of each class (instead of the training fraction times the total doc count). Also relaxes slightly the error bound on the uniformity test from 0.12 to 0.13. Closes #54122 Backport of #58180
This commit is contained in:
parent
85be78b624
commit
36dbf08d47
|
@ -59,11 +59,11 @@ public class StratifiedCrossValidationSplitter implements CrossValidationSplitte
|
|||
|
||||
// We ensure the target sample count is at least 1 as if the cardinality
|
||||
// is too low we might get a target of zero and, thus, no samples of the whole class
|
||||
double targetSampleCount = Math.max(1.0, samplingRatio * sample.cardinality);
|
||||
long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.cardinality);
|
||||
|
||||
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||
// for a class decreases.
|
||||
double p = (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
|
||||
double p = (double) (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
|
||||
|
||||
boolean isTraining = random.nextDouble() <= p;
|
||||
|
||||
|
|
|
@ -177,10 +177,13 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
|
||||
// We can assert we're plus/minus 1 from rounding error
|
||||
|
||||
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
|
||||
long expectedTotalTrainingCount = 0;
|
||||
for (long classCardinality : classCardinalities.values()) {
|
||||
expectedTotalTrainingCount += trainingFraction * classCardinality;
|
||||
}
|
||||
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
|
||||
assertThat(trainingDocsCount, greaterThanOrEqualTo((long) (expectedTotalTrainingCount - 2)));
|
||||
assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount) + 2));
|
||||
assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2));
|
||||
assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount));
|
||||
|
||||
for (String classValue : classCardinalities.keySet()) {
|
||||
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
|
||||
|
@ -221,7 +224,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
// should be close to the training percent, which is set to 0.5
|
||||
for (int rowTrainingCount : trainingCountPerRow) {
|
||||
double meanCount = rowTrainingCount / (double) runCount;
|
||||
assertThat(meanCount, is(closeTo(0.5, 0.12)));
|
||||
assertThat(meanCount, is(closeTo(0.5, 0.13)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue