mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-03-28 10:58:30 +00:00
In data frame analytics classification, a class whose cardinality is too low could end up with no samples being used at all. This commit fixes this by ensuring the target sample count can never be zero. Backport of #56783
This commit is contained in:
parent
14ad733bd1
commit
54d3cc74ec
x-pack/plugin/ml/src
main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation
test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation
@ -57,9 +57,13 @@ public class StratifiedCrossValidationSplitter implements CrossValidationSplitte
|
||||
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
|
||||
}
|
||||
|
||||
// We ensure the target sample count is at least 1 as if the cardinality
|
||||
// is too low we might get a target of zero and, thus, no samples of the whole class
|
||||
double targetSampleCount = Math.max(1.0, samplingRatio * sample.cardinality);
|
||||
|
||||
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||
// for a class decreases.
|
||||
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
|
||||
double p = (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
|
||||
|
||||
boolean isTraining = random.nextDouble() <= p;
|
||||
|
||||
|
@ -225,6 +225,35 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenTwoClassesWithCardinalityEqualToOne_ShouldUseForTraining() {
|
||||
dependentVariable = "dep_var";
|
||||
fields = Arrays.asList(dependentVariable, "feature");
|
||||
classCardinalities = new HashMap<>();
|
||||
classCardinalities.put("class_a", 1L);
|
||||
classCardinalities.put("class_b", 1L);
|
||||
CrossValidationSplitter splitter = createSplitter(80.0);
|
||||
|
||||
{
|
||||
String[] row = new String[]{"class_a", "42.0"};
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
{
|
||||
String[] row = new String[]{"class_b", "42.0"};
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
|
||||
assertThat(trainingDocsCount, equalTo(2L));
|
||||
assertThat(testDocsCount, equalTo(0L));
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent) {
|
||||
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user