Unlike `classification`, which uses a cross validation splitter that produces training sets whose size is predictable and equal to `training_percent * class_cardinality`, for regression we have been using a random splitter that makes an independent decision for each document. This means we cannot predict the exact size of the training set. This poses a problem as we move towards performing test inference on the Java side, because we need to be able to provide an accurate upper bound of the training set size to the C++ process. This commit replaces the random splitter we use for regression with the same streaming-reservoir approach we use for `classification`. Backport of #58331.
This commit is contained in:
parent
e7c40d973e
commit
f67fee387b
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* This is a streaming implementation of a cross validation splitter that
|
||||
* is based on the reservoir idea. It randomly picks training docs while
|
||||
* respecting the exact training percent.
|
||||
*/
|
||||
abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter {
|
||||
|
||||
// Position of the dependent variable in fieldNames; also used to index into every row given to process().
protected final int dependentVariableIndex;
|
||||
// trainingPercent expressed as a fraction, i.e. trainingPercent / 100.
private final double samplingRatio;
|
||||
// Source of the per-row sampling draws; seeded with randomizeSeed in the constructor.
private final Random random;
|
||||
|
||||
/**
* @param fieldNames names of the row columns; must contain dependentVariable
* @param dependentVariable name of the field whose value decides training vs. test
* @param trainingPercent percentage of rows to keep for training; asserted to be within [1, 100]
* @param randomizeSeed seed for the random generator used when sampling
*/
AbstractReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed) {
|
||||
// Only enforced when JVM assertions are enabled; no other range check is performed here.
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.samplingRatio = trainingPercent / 100.0;
|
||||
this.random = new Random(randomizeSeed);
|
||||
}
|
||||
|
||||
/**
* Returns the index of dependentVariable in fieldNames, or throws the exception
* built by ExceptionsHelper.serverError when the name is not in the list.
*/
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||
if (dependentVariableIndex < 0) {
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
return dependentVariableIndex;
|
||||
}
|
||||
|
||||
/**
* Decides whether the given row is a training or a test row and runs exactly one of the two callbacks.
* A row whose dependent variable is already NULL_VALUE is counted as test and left unmodified.
* Any other row is sampled: if it is not picked for training, its dependent variable is
* overwritten in place with NULL_VALUE and it is counted as test.
*/
@Override
|
||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||
|
||||
if (canBeUsedForTraining(row) == false) {
|
||||
incrementTestDocs.run();
|
||||
return;
|
||||
}
|
||||
|
||||
// The returned SampleInfo is mutated below, so subclasses must hand back the same instance for the same class.
SampleInfo sample = getSampleInfo(row);
|
||||
|
||||
// We ensure the target sample count is at least 1 as if the class count
|
||||
// is too low we might get a target of zero and, thus, no samples of the whole class
|
||||
// The (long) cast truncates the fractional part of the target.
long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.classCount);
|
||||
|
||||
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||
// for a class decreases.
|
||||
// p = (training rows still needed) / (rows of this class not yet observed). The (double) cast binds to the
// numerator, so this is floating-point division: a zero denominator yields NaN or +/-Infinity, not an exception.
double p = (double) (targetSampleCount - sample.training) / (sample.classCount - sample.observed);
|
||||
|
||||
// nextDouble() is in [0.0, 1.0), so p >= 1 always selects the row and a NaN or negative p never does.
// NOTE(review): with p == 0.0 a draw of exactly 0.0 would still select the row - confirm `<=` vs `<` is intended.
boolean isTraining = random.nextDouble() <= p;
|
||||
|
||||
sample.observed++;
|
||||
if (isTraining) {
|
||||
sample.training++;
|
||||
incrementTrainingDocs.run();
|
||||
} else {
|
||||
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
|
||||
incrementTestDocs.run();
|
||||
}
|
||||
}
|
||||
|
||||
// Reference comparison (!=), not equals(): only the NULL_VALUE constant instance itself marks a row as unusable.
// NOTE(review): presumably the extractor stores that exact constant in rows with a missing value - confirm.
private boolean canBeUsedForTraining(String[] row) {
|
||||
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
|
||||
}
|
||||
|
||||
// Supplies the counters to use for the given row; called by process() only for rows that can be used for training.
protected abstract SampleInfo getSampleInfo(String[] row);
|
||||
|
||||
/**
|
||||
* Class count, count of docs picked for training, and count of observed
|
||||
*/
|
||||
static class SampleInfo {
|
||||
|
||||
// Total number of rows expected for this class, as supplied by whoever creates the SampleInfo.
private final long classCount;
|
||||
// Rows of this class picked for training so far; incremented by process().
private long training;
|
||||
// Rows of this class that have gone through the sampling decision so far; incremented by process().
private long observed;
|
||||
|
||||
SampleInfo(long classCount) {
|
||||
this.classCount = classCount;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException;
|
|||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
|
@ -41,7 +42,7 @@ public class CrossValidationSplitterFactory {
|
|||
|
||||
public CrossValidationSplitter create() {
|
||||
if (config.getAnalysis() instanceof Regression) {
|
||||
return createRandomSplitter();
|
||||
return createSingleClassSplitter((Regression) config.getAnalysis());
|
||||
}
|
||||
if (config.getAnalysis() instanceof Classification) {
|
||||
return createStratifiedSplitter((Classification) config.getAnalysis());
|
||||
|
@ -49,10 +50,23 @@ public class CrossValidationSplitterFactory {
|
|||
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createRandomSplitter() {
|
||||
Regression regression = (Regression) config.getAnalysis();
|
||||
return new RandomCrossValidationSplitter(
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
||||
/**
* Builds the splitter used for regression: counts the documents in the destination index
* that have a value for the dependent variable and passes that count to a
* SingleClassReservoirCrossValidationSplitter as its class count.
* Any exception raised inside the try block (the search or the splitter construction)
* is logged and rethrown wrapped in an ElasticsearchException.
*/
private CrossValidationSplitter createSingleClassSplitter(Regression regression) {
|
||||
// Count-only request: size 0, with total-hit tracking enabled because the total is the only value read below.
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
|
||||
.setSize(0)
|
||||
.setAllowPartialSearchResults(false)
|
||||
.setTrackTotalHits(true)
|
||||
.setQuery(QueryBuilders.existsQuery(regression.getDependentVariable()));
|
||||
|
||||
try {
|
||||
// Executed with the headers stored on the config and the ML origin.
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
|
||||
searchRequestBuilder::get);
|
||||
return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(),
|
||||
regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value);
|
||||
} catch (Exception e) {
|
||||
ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId());
|
||||
LOGGER.error(msg, e);
|
||||
// The original exception is kept as the cause.
throw new ElasticsearchException(msg.getFormattedMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
|
||||
|
@ -69,12 +83,12 @@ public class CrossValidationSplitterFactory {
|
|||
searchRequestBuilder::get);
|
||||
Aggregations aggs = searchResponse.getAggregations();
|
||||
Terms terms = aggs.get(aggName);
|
||||
Map<String, Long> classCardinalities = new HashMap<>();
|
||||
Map<String, Long> classCounts = new HashMap<>();
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
|
||||
classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
|
||||
}
|
||||
|
||||
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
|
||||
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts,
|
||||
classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||
} catch (Exception e) {
|
||||
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* A cross validation splitter that randomly clears the dependent variable value
|
||||
* in order to split the dataset in training and test data.
|
||||
* This relies on the fact that when the dependent variable field
|
||||
* is empty, then the row is not used for training but only to make predictions.
|
||||
*/
|
||||
class RandomCrossValidationSplitter implements CrossValidationSplitter {
|
||||
|
||||
// Position of the dependent variable in fieldNames; also used to index into every row given to process().
private final int dependentVariableIndex;
|
||||
// Kept as a percentage (not a fraction); compared against a draw scaled to [0, 100) in isPickedForTraining().
private final double trainingPercent;
|
||||
// Source of the per-row draws; seeded with randomizeSeed in the constructor.
private final Random random;
|
||||
// True until isPickedForTraining() is called for the first time.
private boolean isFirstRow = true;
|
||||
|
||||
/**
* @param fieldNames names of the row columns; must contain dependentVariable
* @param dependentVariable name of the field whose value decides training vs. test
* @param trainingPercent percentage of rows to keep for training; asserted to be within [1, 100]
* @param randomizeSeed seed for the random generator used when sampling
*/
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
|
||||
// Only enforced when JVM assertions are enabled.
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.random = new Random(randomizeSeed);
|
||||
}
|
||||
|
||||
/**
* Returns the index of dependentVariable in fieldNames, or throws the exception
* built by ExceptionsHelper.serverError when the name is not in the list.
*/
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||
if (dependentVariableIndex < 0) {
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
return dependentVariableIndex;
|
||||
}
|
||||
|
||||
/**
* Counts the row as training when it has a dependent variable value and is picked by
* isPickedForTraining(); otherwise sets its dependent variable to NULL_VALUE and counts it as test.
* Each row is decided independently, so the number of training rows is not fixed in advance.
*/
@Override
|
||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||
// && short-circuits: isPickedForTraining() (and so the first-row flag and the random draw)
// is only reached for rows that can be used for training.
if (canBeUsedForTraining(row) && isPickedForTraining()) {
|
||||
incrementTrainingDocs.run();
|
||||
} else {
|
||||
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
|
||||
incrementTestDocs.run();
|
||||
}
|
||||
}
|
||||
|
||||
// Reference comparison (!=), not equals(): only the NULL_VALUE constant instance itself marks a row as unusable.
// NOTE(review): presumably the extractor stores that exact constant in rows with a missing value - confirm.
private boolean canBeUsedForTraining(String[] row) {
|
||||
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
|
||||
}
|
||||
|
||||
// Returns true unconditionally on its first call, afterwards true when a draw scaled to [0, 100) is <= trainingPercent.
private boolean isPickedForTraining() {
|
||||
if (isFirstRow) {
|
||||
// Let's make sure we have at least one training row
|
||||
isFirstRow = false;
|
||||
return true;
|
||||
}
|
||||
return random.nextDouble() * 100 <= trainingPercent;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
|
||||
|
||||
// The one set of counters shared by every row: all rows are treated as belonging to a single class.
private final SampleInfo sampleInfo;
|
||||
|
||||
/**
* @param fieldNames names of the row columns; forwarded to the superclass
* @param dependentVariable name of the dependent variable field; forwarded to the superclass
* @param trainingPercent percentage of rows to keep for training; forwarded to the superclass
* @param randomizeSeed seed for the random generator; forwarded to the superclass
* @param classCount total row count used to initialise the single SampleInfo
*/
SingleClassReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed, long classCount) {
|
||||
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
|
||||
sampleInfo = new SampleInfo(classCount);
|
||||
}
|
||||
|
||||
// Ignores the row content and always returns the same SampleInfo instance.
@Override
|
||||
protected SampleInfo getSampleInfo(String[] row) {
|
||||
return sampleInfo;
|
||||
}
|
||||
}
|
|
@ -6,89 +6,32 @@
|
|||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Given a dependent variable, randomly splits the dataset trying
|
||||
* to preserve the proportion of each class in the training sample.
|
||||
*/
|
||||
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {
|
||||
public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
|
||||
|
||||
private final int dependentVariableIndex;
|
||||
private final double samplingRatio;
|
||||
private final Random random;
|
||||
private final Map<String, ClassSample> classSamples;
|
||||
private final Map<String, SampleInfo> classSamples;
|
||||
|
||||
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
|
||||
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
|
||||
double trainingPercent, long randomizeSeed) {
|
||||
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.samplingRatio = trainingPercent / 100.0;
|
||||
this.random = new Random(randomizeSeed);
|
||||
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
|
||||
this.classSamples = new HashMap<>();
|
||||
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||
if (dependentVariableIndex < 0) {
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
return dependentVariableIndex;
|
||||
classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue())));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||
|
||||
if (canBeUsedForTraining(row) == false) {
|
||||
incrementTestDocs.run();
|
||||
return;
|
||||
}
|
||||
|
||||
protected SampleInfo getSampleInfo(String[] row) {
|
||||
String classValue = row[dependentVariableIndex];
|
||||
ClassSample sample = classSamples.get(classValue);
|
||||
SampleInfo sample = classSamples.get(classValue);
|
||||
if (sample == null) {
|
||||
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
|
||||
}
|
||||
|
||||
// We ensure the target sample count is at least 1 as if the cardinality
|
||||
// is too low we might get a target of zero and, thus, no samples of the whole class
|
||||
long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.cardinality);
|
||||
|
||||
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||
// for a class decreases.
|
||||
double p = (double) (targetSampleCount - sample.training) / (sample.cardinality - sample.observed);
|
||||
|
||||
boolean isTraining = random.nextDouble() <= p;
|
||||
|
||||
sample.observed++;
|
||||
if (isTraining) {
|
||||
sample.training++;
|
||||
incrementTrainingDocs.run();
|
||||
} else {
|
||||
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
|
||||
incrementTestDocs.run();
|
||||
}
|
||||
}
|
||||
|
||||
private boolean canBeUsedForTraining(String[] row) {
|
||||
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
|
||||
}
|
||||
|
||||
private static class ClassSample {
|
||||
|
||||
private final long cardinality;
|
||||
private long training;
|
||||
private long observed;
|
||||
|
||||
private ClassSample(long cardinality) {
|
||||
this.cardinality = cardinality;
|
||||
}
|
||||
return sample;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,12 +15,10 @@ import java.util.List;
|
|||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.Matchers.both;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
|
||||
public class RandomCrossValidationSplitterTests extends ESTestCase {
|
||||
public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase {
|
||||
|
||||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
|
@ -42,7 +40,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(50.0);
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(50.0, 0);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -62,7 +60,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(100.0);
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(100.0, 100L);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -83,14 +81,14 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
|
||||
double trainingFraction = trainingPercent / 100;
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
|
||||
long rowCount = 1000;
|
||||
|
||||
int runCount = 20;
|
||||
int rowsCount = 1000;
|
||||
int[] trainingRowsPerRun = new int[runCount];
|
||||
for (int testIndex = 0; testIndex < runCount; testIndex++) {
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent, rowCount);
|
||||
int trainingRows = 0;
|
||||
for (int i = 0; i < rowsCount; i++) {
|
||||
for (int i = 0; i < rowCount; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
row[fieldIndex] = randomAlphaOfLength(10);
|
||||
|
@ -113,23 +111,11 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble();
|
||||
|
||||
// Now we need to calculate sensible bounds to assert against.
|
||||
// We'll use 5 variances which should mean the test only fails once in 7M
|
||||
// And, because we're doing multiple runs, we'll divide the variance with the number of runs to narrow the bounds
|
||||
double expectedTrainingRows = trainingFraction * rowsCount;
|
||||
double variance = rowsCount * (Math.pow(1 - trainingFraction, 2) * trainingFraction
|
||||
+ Math.pow(trainingFraction, 2) * (1 - trainingFraction));
|
||||
double lowerBound = expectedTrainingRows - 5 * Math.sqrt(variance / runCount);
|
||||
double upperBound = expectedTrainingRows + 5 * Math.sqrt(variance / runCount);
|
||||
|
||||
assertThat("Mean training rows [" + meanTrainingRows + "] was not within expected bounds of [" + lowerBound + ", "
|
||||
+ upperBound + "] given training fraction was [" + trainingFraction + "]",
|
||||
meanTrainingRows, is(both(greaterThan(lowerBound)).and(lessThan(upperBound))));
|
||||
assertThat(meanTrainingRows, closeTo(trainingFraction * rowCount, 1.0));
|
||||
}
|
||||
|
||||
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(1.0);
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(1.0, 1);
|
||||
|
||||
// We have some non-training rows and then a training row to check
|
||||
// we maintain the first training row and not just the first row
|
||||
|
@ -152,8 +138,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase {
|
|||
assertThat(testDocsCount, equalTo(9L));
|
||||
}
|
||||
|
||||
private RandomCrossValidationSplitter createSplitter(double trainingPercent) {
|
||||
return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed);
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) {
|
||||
return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
|
||||
}
|
||||
|
||||
private void incrementTrainingDocsCount() {
|
|
@ -32,7 +32,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
private int dependentVariableIndex;
|
||||
private String dependentVariable;
|
||||
private long randomizeSeed;
|
||||
private Map<String, Long> classCardinalities;
|
||||
private Map<String, Long> classCounts;
|
||||
private String[] classValuesPerRow;
|
||||
private long trainingDocsCount;
|
||||
private long testDocsCount;
|
||||
|
@ -68,10 +68,10 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
classCardinalities = new HashMap<>();
|
||||
classCardinalities.put("a", classA);
|
||||
classCardinalities.put("b", classB);
|
||||
classCardinalities.put("c", classC);
|
||||
classCounts = new HashMap<>();
|
||||
classCounts.put("a", classA);
|
||||
classCounts.put("b", classB);
|
||||
classCounts.put("c", classC);
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMissingDependentVariable() {
|
||||
|
@ -143,7 +143,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
Map<String, Integer> totalRowsPerClass = new HashMap<>();
|
||||
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
|
||||
|
||||
for (String classValue : classCardinalities.keySet()) {
|
||||
for (String classValue : classCounts.keySet()) {
|
||||
totalRowsPerClass.put(classValue, 0);
|
||||
trainingRowsPerClass.put(classValue, 0);
|
||||
}
|
||||
|
@ -178,14 +178,14 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
// We can assert we're plus/minus 1 from rounding error
|
||||
|
||||
long expectedTotalTrainingCount = 0;
|
||||
for (long classCardinality : classCardinalities.values()) {
|
||||
expectedTotalTrainingCount += trainingFraction * classCardinality;
|
||||
for (long classCount : classCounts.values()) {
|
||||
expectedTotalTrainingCount += trainingFraction * classCount;
|
||||
}
|
||||
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
|
||||
assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2));
|
||||
assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount));
|
||||
|
||||
for (String classValue : classCardinalities.keySet()) {
|
||||
for (String classValue : classCounts.keySet()) {
|
||||
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
|
||||
int classTrainingCount = trainingRowsPerClass.get(classValue);
|
||||
assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0)));
|
||||
|
@ -228,12 +228,12 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenTwoClassesWithCardinalityEqualToOne_ShouldUseForTraining() {
|
||||
public void testProcess_GivenTwoClassesWithCountEqualToOne_ShouldUseForTraining() {
|
||||
dependentVariable = "dep_var";
|
||||
fields = Arrays.asList(dependentVariable, "feature");
|
||||
classCardinalities = new HashMap<>();
|
||||
classCardinalities.put("class_a", 1L);
|
||||
classCardinalities.put("class_b", 1L);
|
||||
classCounts = new HashMap<>();
|
||||
classCounts.put("class_a", 1L);
|
||||
classCounts.put("class_b", 1L);
|
||||
CrossValidationSplitter splitter = createSplitter(80.0);
|
||||
|
||||
{
|
||||
|
@ -258,7 +258,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent) {
|
||||
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed);
|
||||
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
private void incrementTrainingDocsCount() {
|
||||
|
|
Loading…
Reference in New Issue