[7.x][ML] Ensure class is represented when its cardinality is low

In DF analytics classification, it is possible that no samples
of a class are used for training if its cardinality is too low.

This commit fixes this by ensuring the target sample count can never be zero.

Backport of 
This commit is contained in:
Dimitris Athanasiou 2020-05-15 20:52:06 +03:00 committed by GitHub
parent 14ad733bd1
commit 54d3cc74ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 1 deletions
x-pack/plugin/ml/src
main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation
test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation

@ -57,9 +57,13 @@ public class StratifiedCrossValidationSplitter implements CrossValidationSplitte
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
}
// We ensure the target sample count is at least 1; otherwise, if the cardinality
// is too low, we might get a target of zero and, thus, no samples of the whole class
double targetSampleCount = Math.max(1.0, samplingRatio * sample.cardinality);
// The idea here is that the probability increases as the chances we have left to reach
// the target proportion for a class decrease.
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
double p = (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
boolean isTraining = random.nextDouble() <= p;

@ -225,6 +225,35 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
}
/**
 * Two classes, each with a cardinality of one, are fed through a splitter created with a
 * training percent of 80. Every row must come back unmodified, and both documents must be
 * reported through the training callback while the test callback is never invoked.
 */
public void testProcess_GivenTwoClassesWithCardinalityEqualToOne_ShouldUseForTraining() {
    final String[] classNames = {"class_a", "class_b"};

    dependentVariable = "dep_var";
    fields = Arrays.asList(dependentVariable, "feature");
    classCardinalities = new HashMap<>();
    for (String className : classNames) {
        // Each class is represented by exactly one document.
        classCardinalities.put(className, 1L);
    }

    CrossValidationSplitter splitter = createSplitter(80.0);

    // One row per class, processed in the same order the classes were declared.
    for (String className : classNames) {
        String[] expectedRow = {className, "42.0"};
        String[] actualRow = expectedRow.clone();

        splitter.process(actualRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);

        // The splitter must not have altered any value of the row.
        assertThat(Arrays.equals(actualRow, expectedRow), is(true));
    }

    assertThat(trainingDocsCount, equalTo(2L));
    assertThat(testDocsCount, equalTo(0L));
}
/**
 * Builds the splitter under test from the fixture state currently held by this test class
 * ({@code fields}, {@code dependentVariable}, {@code classCardinalities}, {@code randomizeSeed}).
 *
 * @param trainingPercent the training percent handed to the splitter's constructor
 * @return a new {@link StratifiedCrossValidationSplitter}
 */
private CrossValidationSplitter createSplitter(double trainingPercent) {
    return new StratifiedCrossValidationSplitter(
        fields,
        dependentVariable,
        classCardinalities,
        trainingPercent,
        randomizeSeed
    );
}