[7.x][ML] Make regression training set predictable in size (#58331) (#58453)

Unlike `classification`, which is using a cross validation splitter
that produces training sets whose size is predictable and equal to
`training_percent * class_cardinality`, for regression we have been
using a random splitter that takes an independent decision for each
document. This means we cannot predict the exact size of the training
set. This poses a problem as we move towards performing test inference
on the java side as we need to be able to provide an accurate upper
bound of the training set size to the c++ process.

This commit replaces the random splitter we use for regression with
the same streaming-reservoir approach we use for `classification`.

Backport of #58331
This commit is contained in:
Dimitris Athanasiou 2020-06-23 19:49:03 +03:00 committed by GitHub
parent e7c40d973e
commit f67fee387b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 171 additions and 176 deletions

View File

@@ -0,0 +1,91 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import java.util.List;
import java.util.Random;
/**
 * This is a streaming implementation of a cross validation splitter that
 * is based on the reservoir idea. It randomly picks training docs while
 * respecting the exact training percent.
 */
abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter {

    // Index of the dependent variable within each row's field array
    protected final int dependentVariableIndex;

    // Fraction of eligible docs to pick for training (trainingPercent / 100)
    private final double samplingRatio;

    // Seeded so the train/test split is reproducible for a given randomizeSeed
    private final Random random;

    AbstractReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
                                             long randomizeSeed) {
        assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
        this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
        this.samplingRatio = trainingPercent / 100.0;
        this.random = new Random(randomizeSeed);
    }

    /**
     * Returns the index of {@code dependentVariable} within {@code fieldNames},
     * throwing a server error if it is not present.
     */
    private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
        int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
        if (dependentVariableIndex < 0) {
            throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
        }
        return dependentVariableIndex;
    }

    /**
     * Decides whether the given row is a training or a test doc and invokes the
     * corresponding counter callback. Test rows get their dependent variable
     * cleared so they are excluded from training downstream.
     */
    @Override
    public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
        // Rows without a dependent variable value can never be trained on
        if (canBeUsedForTraining(row) == false) {
            incrementTestDocs.run();
            return;
        }
        SampleInfo sample = getSampleInfo(row);

        // We ensure the target sample count is at least 1 as if the class count
        // is too low we might get a target of zero and, thus, no samples of the whole class
        long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.classCount);

        // The idea here is that the probability increases as the chances we have to get the target proportion
        // for a class decreases.
        // Note: p must be computed BEFORE observed is incremented below — the order is load-bearing.
        double p = (double) (targetSampleCount - sample.training) / (sample.classCount - sample.observed);

        boolean isTraining = random.nextDouble() <= p;
        sample.observed++;
        if (isTraining) {
            sample.training++;
            incrementTrainingDocs.run();
        } else {
            // Clearing the dependent variable marks the row as test-only
            row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
            incrementTestDocs.run();
        }
    }

    // NOTE(review): reference (!=) comparison appears intentional — assumes the
    // extractor always uses the NULL_VALUE sentinel instance itself; confirm.
    private boolean canBeUsedForTraining(String[] row) {
        return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
    }

    /**
     * Returns the {@link SampleInfo} that tracks reservoir state for the class
     * the given row belongs to.
     */
    protected abstract SampleInfo getSampleInfo(String[] row);

    /**
     * Class count, count of docs picked for training, and count of observed
     */
    static class SampleInfo {

        // Total docs of this class (known up front from a prior search)
        private final long classCount;

        // Docs of this class picked for training so far
        private long training;

        // Docs of this class processed so far
        private long observed;

        SampleInfo(long classCount) {
            this.classCount = classCount;
        }
    }
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms; import org.elasticsearch.search.aggregations.bucket.terms.Terms;
@ -41,7 +42,7 @@ public class CrossValidationSplitterFactory {
public CrossValidationSplitter create() { public CrossValidationSplitter create() {
if (config.getAnalysis() instanceof Regression) { if (config.getAnalysis() instanceof Regression) {
return createRandomSplitter(); return createSingleClassSplitter((Regression) config.getAnalysis());
} }
if (config.getAnalysis() instanceof Classification) { if (config.getAnalysis() instanceof Classification) {
return createStratifiedSplitter((Classification) config.getAnalysis()); return createStratifiedSplitter((Classification) config.getAnalysis());
@ -49,10 +50,23 @@ public class CrossValidationSplitterFactory {
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
} }
private CrossValidationSplitter createRandomSplitter() { private CrossValidationSplitter createSingleClassSplitter(Regression regression) {
Regression regression = (Regression) config.getAnalysis(); SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
return new RandomCrossValidationSplitter( .setSize(0)
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed()); .setAllowPartialSearchResults(false)
.setTrackTotalHits(true)
.setQuery(QueryBuilders.existsQuery(regression.getDependentVariable()));
try {
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
searchRequestBuilder::get);
return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(),
regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value);
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId());
LOGGER.error(msg, e);
throw new ElasticsearchException(msg.getFormattedMessage(), e);
}
} }
private CrossValidationSplitter createStratifiedSplitter(Classification classification) { private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
@ -69,12 +83,12 @@ public class CrossValidationSplitterFactory {
searchRequestBuilder::get); searchRequestBuilder::get);
Aggregations aggs = searchResponse.getAggregations(); Aggregations aggs = searchResponse.getAggregations();
Terms terms = aggs.get(aggName); Terms terms = aggs.get(aggName);
Map<String, Long> classCardinalities = new HashMap<>(); Map<String, Long> classCounts = new HashMap<>();
for (Terms.Bucket bucket : terms.getBuckets()) { for (Terms.Bucket bucket : terms.getBuckets()) {
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount()); classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
} }
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities, return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts,
classification.getTrainingPercent(), classification.getRandomizeSeed()); classification.getTrainingPercent(), classification.getRandomizeSeed());
} catch (Exception e) { } catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId()); ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());

View File

@@ -1,64 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import java.util.List;
import java.util.Random;
/**
 * A cross validation splitter that randomly clears the dependent variable value
 * in order to split the dataset in training and test data.
 * This relies on the fact that when the dependent variable field
 * is empty, then the row is not used for training but only to make predictions.
 */
class RandomCrossValidationSplitter implements CrossValidationSplitter {

    private final int depVarIndex;
    private final double trainingPercent;
    private final Random random;

    // Tracks whether we are still waiting for the first eligible training row,
    // which is always accepted to guarantee a non-empty training set.
    private boolean awaitingFirstTrainingRow = true;

    RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
        assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
        this.depVarIndex = resolveDependentVariableIndex(fieldNames, dependentVariable);
        this.trainingPercent = trainingPercent;
        this.random = new Random(randomizeSeed);
    }

    /** Looks up the dependent variable's position, failing fast when it is absent. */
    private static int resolveDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
        int index = fieldNames.indexOf(dependentVariable);
        if (index >= 0) {
            return index;
        }
        throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
    }

    @Override
    public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
        // Short-circuit keeps the RNG untouched for rows lacking a dependent variable value
        if (hasDependentVariableValue(row) && pickForTraining()) {
            incrementTrainingDocs.run();
            return;
        }
        row[depVarIndex] = DataFrameDataExtractor.NULL_VALUE;
        incrementTestDocs.run();
    }

    private boolean hasDependentVariableValue(String[] row) {
        return row[depVarIndex] != DataFrameDataExtractor.NULL_VALUE;
    }

    private boolean pickForTraining() {
        if (awaitingFirstTrainingRow) {
            // Let's make sure we have at least one training row
            awaitingFirstTrainingRow = false;
            return true;
        }
        double draw = random.nextDouble() * 100;
        return draw <= trainingPercent;
    }
}

View File

@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import java.util.List;
/**
 * A reservoir-based cross validation splitter that treats the whole dataset
 * as a single class, so the training set size is trainingPercent of all docs
 * that have a dependent variable value.
 */
public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {

    // Shared state for the single "class" covering the entire dataset
    private final SampleInfo sample;

    SingleClassReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
                                                long randomizeSeed, long classCount) {
        super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
        this.sample = new SampleInfo(classCount);
    }

    @Override
    protected SampleInfo getSampleInfo(String[] row) {
        // Every row maps to the same sample since there is only one class
        return sample;
    }
}

View File

@ -6,89 +6,32 @@
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random;
/** /**
* Given a dependent variable, randomly splits the dataset trying * Given a dependent variable, randomly splits the dataset trying
* to preserve the proportion of each class in the training sample. * to preserve the proportion of each class in the training sample.
*/ */
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter { public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
private final int dependentVariableIndex; private final Map<String, SampleInfo> classSamples;
private final double samplingRatio;
private final Random random;
private final Map<String, ClassSample> classSamples;
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities, public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
double trainingPercent, long randomizeSeed) { double trainingPercent, long randomizeSeed) {
assert trainingPercent >= 1.0 && trainingPercent <= 100.0; super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.samplingRatio = trainingPercent / 100.0;
this.random = new Random(randomizeSeed);
this.classSamples = new HashMap<>(); this.classSamples = new HashMap<>();
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue()))); classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue())));
}
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
if (dependentVariableIndex < 0) {
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
}
return dependentVariableIndex;
} }
@Override @Override
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { protected SampleInfo getSampleInfo(String[] row) {
if (canBeUsedForTraining(row) == false) {
incrementTestDocs.run();
return;
}
String classValue = row[dependentVariableIndex]; String classValue = row[dependentVariableIndex];
ClassSample sample = classSamples.get(classValue); SampleInfo sample = classSamples.get(classValue);
if (sample == null) { if (sample == null) {
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet()); throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
} }
return sample;
// We ensure the target sample count is at least 1 as if the cardinality
// is too low we might get a target of zero and, thus, no samples of the whole class
long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.cardinality);
// The idea here is that the probability increases as the chances we have to get the target proportion
// for a class decreases.
double p = (double) (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
boolean isTraining = random.nextDouble() <= p;
sample.observed++;
if (isTraining) {
sample.training++;
incrementTrainingDocs.run();
} else {
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
incrementTestDocs.run();
}
}
private boolean canBeUsedForTraining(String[] row) {
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
}
private static class ClassSample {
private final long cardinality;
private long training;
private long observed;
private ClassSample(long cardinality) {
this.cardinality = cardinality;
}
} }
} }

View File

@ -15,12 +15,10 @@ import java.util.List;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.both; import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
public class RandomCrossValidationSplitterTests extends ESTestCase { public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase {
private List<String> fields; private List<String> fields;
private int dependentVariableIndex; private int dependentVariableIndex;
@ -42,7 +40,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
} }
public void testProcess_GivenRowsWithoutDependentVariableValue() { public void testProcess_GivenRowsWithoutDependentVariableValue() {
CrossValidationSplitter crossValidationSplitter = createSplitter(50.0); CrossValidationSplitter crossValidationSplitter = createSplitter(50.0, 0);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()]; String[] row = new String[fields.size()];
@ -62,7 +60,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
} }
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CrossValidationSplitter crossValidationSplitter = createSplitter(100.0); CrossValidationSplitter crossValidationSplitter = createSplitter(100.0, 100L);
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()]; String[] row = new String[fields.size()];
@ -83,14 +81,14 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
double trainingFraction = trainingPercent / 100; double trainingFraction = trainingPercent / 100;
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent); long rowCount = 1000;
int runCount = 20; int runCount = 20;
int rowsCount = 1000;
int[] trainingRowsPerRun = new int[runCount]; int[] trainingRowsPerRun = new int[runCount];
for (int testIndex = 0; testIndex < runCount; testIndex++) { for (int testIndex = 0; testIndex < runCount; testIndex++) {
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent, rowCount);
int trainingRows = 0; int trainingRows = 0;
for (int i = 0; i < rowsCount; i++) { for (int i = 0; i < rowCount; i++) {
String[] row = new String[fields.size()]; String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
row[fieldIndex] = randomAlphaOfLength(10); row[fieldIndex] = randomAlphaOfLength(10);
@ -113,23 +111,11 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
} }
double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble(); double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble();
assertThat(meanTrainingRows, closeTo(trainingFraction * rowCount, 1.0));
// Now we need to calculate sensible bounds to assert against.
// We'll use 5 variances which should mean the test only fails once in 7M
// And, because we're doing multiple runs, we'll divide the variance with the number of runs to narrow the bounds
double expectedTrainingRows = trainingFraction * rowsCount;
double variance = rowsCount * (Math.pow(1 - trainingFraction, 2) * trainingFraction
+ Math.pow(trainingFraction, 2) * (1 - trainingFraction));
double lowerBound = expectedTrainingRows - 5 * Math.sqrt(variance / runCount);
double upperBound = expectedTrainingRows + 5 * Math.sqrt(variance / runCount);
assertThat("Mean training rows [" + meanTrainingRows + "] was not within expected bounds of [" + lowerBound + ", "
+ upperBound + "] given training fraction was [" + trainingFraction + "]",
meanTrainingRows, is(both(greaterThan(lowerBound)).and(lessThan(upperBound))));
} }
public void testProcess_ShouldHaveAtLeastOneTrainingRow() { public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
CrossValidationSplitter crossValidationSplitter = createSplitter(1.0); CrossValidationSplitter crossValidationSplitter = createSplitter(1.0, 1);
// We have some non-training rows and then a training row to check // We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row // we maintain the first training row and not just the first row
@ -152,8 +138,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
assertThat(testDocsCount, equalTo(9L)); assertThat(testDocsCount, equalTo(9L));
} }
private RandomCrossValidationSplitter createSplitter(double trainingPercent) { private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) {
return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed); return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
} }
private void incrementTrainingDocsCount() { private void incrementTrainingDocsCount() {

View File

@ -32,7 +32,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
private int dependentVariableIndex; private int dependentVariableIndex;
private String dependentVariable; private String dependentVariable;
private long randomizeSeed; private long randomizeSeed;
private Map<String, Long> classCardinalities; private Map<String, Long> classCounts;
private String[] classValuesPerRow; private String[] classValuesPerRow;
private long trainingDocsCount; private long trainingDocsCount;
private long testDocsCount; private long testDocsCount;
@ -68,10 +68,10 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
} }
} }
classCardinalities = new HashMap<>(); classCounts = new HashMap<>();
classCardinalities.put("a", classA); classCounts.put("a", classA);
classCardinalities.put("b", classB); classCounts.put("b", classB);
classCardinalities.put("c", classC); classCounts.put("c", classC);
} }
public void testConstructor_GivenMissingDependentVariable() { public void testConstructor_GivenMissingDependentVariable() {
@ -143,7 +143,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
Map<String, Integer> totalRowsPerClass = new HashMap<>(); Map<String, Integer> totalRowsPerClass = new HashMap<>();
Map<String, Integer> trainingRowsPerClass = new HashMap<>(); Map<String, Integer> trainingRowsPerClass = new HashMap<>();
for (String classValue : classCardinalities.keySet()) { for (String classValue : classCounts.keySet()) {
totalRowsPerClass.put(classValue, 0); totalRowsPerClass.put(classValue, 0);
trainingRowsPerClass.put(classValue, 0); trainingRowsPerClass.put(classValue, 0);
} }
@ -178,14 +178,14 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
// We can assert we're plus/minus 1 from rounding error // We can assert we're plus/minus 1 from rounding error
long expectedTotalTrainingCount = 0; long expectedTotalTrainingCount = 0;
for (long classCardinality : classCardinalities.values()) { for (long classCount : classCounts.values()) {
expectedTotalTrainingCount += trainingFraction * classCardinality; expectedTotalTrainingCount += trainingFraction * classCount;
} }
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT)); assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2)); assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2));
assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount)); assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount));
for (String classValue : classCardinalities.keySet()) { for (String classValue : classCounts.keySet()) {
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction; double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
int classTrainingCount = trainingRowsPerClass.get(classValue); int classTrainingCount = trainingRowsPerClass.get(classValue);
assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0))); assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0)));
@ -228,12 +228,12 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
} }
} }
public void testProcess_GivenTwoClassesWithCardinalityEqualToOne_ShouldUseForTraining() { public void testProcess_GivenTwoClassesWithCountEqualToOne_ShouldUseForTraining() {
dependentVariable = "dep_var"; dependentVariable = "dep_var";
fields = Arrays.asList(dependentVariable, "feature"); fields = Arrays.asList(dependentVariable, "feature");
classCardinalities = new HashMap<>(); classCounts = new HashMap<>();
classCardinalities.put("class_a", 1L); classCounts.put("class_a", 1L);
classCardinalities.put("class_b", 1L); classCounts.put("class_b", 1L);
CrossValidationSplitter splitter = createSplitter(80.0); CrossValidationSplitter splitter = createSplitter(80.0);
{ {
@ -258,7 +258,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
} }
private CrossValidationSplitter createSplitter(double trainingPercent) { private CrossValidationSplitter createSplitter(double trainingPercent) {
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed); return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
} }
private void incrementTrainingDocsCount() { private void incrementTrainingDocsCount() {