From f67fee387b16a7a6b4a8b8055f8b8a9500f5446e Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 23 Jun 2020 19:49:03 +0300 Subject: [PATCH] [7.x][ML] Make regression training set predictable in size (#58331) (#58453) Unlike `classification`, which uses a cross validation splitter that produces training sets whose size is predictable and equal to `training_percent * class_cardinality`, for regression we have been using a random splitter that makes an independent decision for each document. This means we cannot predict the exact size of the training set. This poses a problem as we move towards performing test inference on the Java side, because we need to provide an accurate upper bound of the training set size to the C++ process. This commit replaces the random splitter we use for regression with the same streaming-reservoir approach we use for `classification`. Backport of #58331 --- ...tractReservoirCrossValidationSplitter.java | 91 +++++++++++++++++++ .../CrossValidationSplitterFactory.java | 30 ++++-- .../RandomCrossValidationSplitter.java | 64 ------------- ...ClassReservoirCrossValidationSplitter.java | 25 +++++ .../StratifiedCrossValidationSplitter.java | 73 ++------------- ...eservoirCrossValidationSplitterTests.java} | 36 +++----- ...tratifiedCrossValidationSplitterTests.java | 28 +++--- 7 files changed, 171 insertions(+), 176 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/{RandomCrossValidationSplitterTests.java => 
SingleClassReservoirCrossValidationSplitterTests.java} (78%) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java new file mode 100644 index 00000000000..33837c0f622 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; + +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; + +import java.util.List; +import java.util.Random; + +/** + * This is a streaming implementation of a cross validation splitter that + * is based on the reservoir idea. It randomly picks training docs while + * respecting the exact training percent. 
+ */ +abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter { + + protected final int dependentVariableIndex; + private final double samplingRatio; + private final Random random; + + AbstractReservoirCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, + long randomizeSeed) { + assert trainingPercent >= 1.0 && trainingPercent <= 100.0; + this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); + this.samplingRatio = trainingPercent / 100.0; + this.random = new Random(randomizeSeed); + } + + private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { + int dependentVariableIndex = fieldNames.indexOf(dependentVariable); + if (dependentVariableIndex < 0) { + throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames); + } + return dependentVariableIndex; + } + + @Override + public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { + + if (canBeUsedForTraining(row) == false) { + incrementTestDocs.run(); + return; + } + + SampleInfo sample = getSampleInfo(row); + + // We ensure the target sample count is at least 1 as if the class count + // is too low we might get a target of zero and, thus, no samples of the whole class + long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.classCount); + + // The idea here is that the probability increases as the chances we have to get the target proportion + // for a class decreases. 
+ double p = (double) (targetSampleCount - sample.training) / (sample.classCount - sample.observed); + + boolean isTraining = random.nextDouble() <= p; + + sample.observed++; + if (isTraining) { + sample.training++; + incrementTrainingDocs.run(); + } else { + row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; + incrementTestDocs.run(); + } + } + + private boolean canBeUsedForTraining(String[] row) { + return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; + } + + protected abstract SampleInfo getSampleInfo(String[] row); + + /** + * Class count, count of docs picked for training, and count of observed + */ + static class SampleInfo { + + private final long classCount; + private long training; + private long observed; + + SampleInfo(long classCount) { + this.classCount = classCount; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java index 7632916c0d4..6701de1bf34 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregationBuilders; import org.elasticsearch.search.aggregations.Aggregations; import org.elasticsearch.search.aggregations.bucket.terms.Terms; @@ -41,7 +42,7 @@ public class CrossValidationSplitterFactory { public CrossValidationSplitter create() { if 
(config.getAnalysis() instanceof Regression) { - return createRandomSplitter(); + return createSingleClassSplitter((Regression) config.getAnalysis()); } if (config.getAnalysis() instanceof Classification) { return createStratifiedSplitter((Classification) config.getAnalysis()); @@ -49,10 +50,23 @@ public class CrossValidationSplitterFactory { return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); } - private CrossValidationSplitter createRandomSplitter() { - Regression regression = (Regression) config.getAnalysis(); - return new RandomCrossValidationSplitter( - fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed()); + private CrossValidationSplitter createSingleClassSplitter(Regression regression) { + SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex()) + .setSize(0) + .setAllowPartialSearchResults(false) + .setTrackTotalHits(true) + .setQuery(QueryBuilders.existsQuery(regression.getDependentVariable())); + + try { + SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, + searchRequestBuilder::get); + return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(), + regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value); + } catch (Exception e) { + ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId()); + LOGGER.error(msg, e); + throw new ElasticsearchException(msg.getFormattedMessage(), e); + } } private CrossValidationSplitter createStratifiedSplitter(Classification classification) { @@ -69,12 +83,12 @@ public class CrossValidationSplitterFactory { searchRequestBuilder::get); Aggregations aggs = searchResponse.getAggregations(); Terms terms = aggs.get(aggName); - Map classCardinalities = new HashMap<>(); + Map 
classCounts = new HashMap<>(); for (Terms.Bucket bucket : terms.getBuckets()) { - classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount()); + classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount()); } - return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities, + return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts, classification.getTrainingPercent(), classification.getRandomizeSeed()); } catch (Exception e) { ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java deleted file mode 100644 index 2bd0a209848..00000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; - -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; - -import java.util.List; -import java.util.Random; - -/** - * A cross validation splitter that randomly clears the dependent variable value - * in order to split the dataset in training and test data. - * This relies on the fact that when the dependent variable field - * is empty, then the row is not used for training but only to make predictions. 
- */ -class RandomCrossValidationSplitter implements CrossValidationSplitter { - - private final int dependentVariableIndex; - private final double trainingPercent; - private final Random random; - private boolean isFirstRow = true; - - RandomCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { - assert trainingPercent >= 1.0 && trainingPercent <= 100.0; - this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); - this.trainingPercent = trainingPercent; - this.random = new Random(randomizeSeed); - } - - private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { - int dependentVariableIndex = fieldNames.indexOf(dependentVariable); - if (dependentVariableIndex < 0) { - throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames); - } - return dependentVariableIndex; - } - - @Override - public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { - if (canBeUsedForTraining(row) && isPickedForTraining()) { - incrementTrainingDocs.run(); - } else { - row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; - incrementTestDocs.run(); - } - } - - private boolean canBeUsedForTraining(String[] row) { - return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; - } - - private boolean isPickedForTraining() { - if (isFirstRow) { - // Let's make sure we have at least one training row - isFirstRow = false; - return true; - } - return random.nextDouble() * 100 <= trainingPercent; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java new file mode 100644 index 00000000000..3159f6d89c0 --- 
/dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; + +import java.util.List; + +public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter { + + private final SampleInfo sampleInfo; + + SingleClassReservoirCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, + long randomizeSeed, long classCount) { + super(fieldNames, dependentVariable, trainingPercent, randomizeSeed); + sampleInfo = new SampleInfo(classCount); + } + + @Override + protected SampleInfo getSampleInfo(String[] row) { + return sampleInfo; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java index 7b47a1e1fdf..503d6308132 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java @@ -6,89 +6,32 @@ package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; - import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Random; /** * Given a dependent variable, 
randomly splits the dataset trying * to preserve the proportion of each class in the training sample. */ -public class StratifiedCrossValidationSplitter implements CrossValidationSplitter { +public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter { - private final int dependentVariableIndex; - private final double samplingRatio; - private final Random random; - private final Map classSamples; + private final Map classSamples; - public StratifiedCrossValidationSplitter(List fieldNames, String dependentVariable, Map classCardinalities, + public StratifiedCrossValidationSplitter(List fieldNames, String dependentVariable, Map classCounts, double trainingPercent, long randomizeSeed) { - assert trainingPercent >= 1.0 && trainingPercent <= 100.0; - this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); - this.samplingRatio = trainingPercent / 100.0; - this.random = new Random(randomizeSeed); + super(fieldNames, dependentVariable, trainingPercent, randomizeSeed); this.classSamples = new HashMap<>(); - classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue()))); - } - - private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { - int dependentVariableIndex = fieldNames.indexOf(dependentVariable); - if (dependentVariableIndex < 0) { - throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames); - } - return dependentVariableIndex; + classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue()))); } @Override - public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) { - - if (canBeUsedForTraining(row) == false) { - incrementTestDocs.run(); - return; - } - + protected SampleInfo getSampleInfo(String[] row) { String classValue = row[dependentVariableIndex]; - ClassSample 
sample = classSamples.get(classValue); + SampleInfo sample = classSamples.get(classValue); if (sample == null) { throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet()); } - - // We ensure the target sample count is at least 1 as if the cardinality - // is too low we might get a target of zero and, thus, no samples of the whole class - long targetSampleCount = (long) Math.max(1.0, samplingRatio * sample.cardinality); - - // The idea here is that the probability increases as the chances we have to get the target proportion - // for a class decreases. - double p = (double) (targetSampleCount - sample.training) / (sample.cardinality - sample.observed); - - boolean isTraining = random.nextDouble() <= p; - - sample.observed++; - if (isTraining) { - sample.training++; - incrementTrainingDocs.run(); - } else { - row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE; - incrementTestDocs.run(); - } - } - - private boolean canBeUsedForTraining(String[] row) { - return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE; - } - - private static class ClassSample { - - private final long cardinality; - private long training; - private long observed; - - private ClassSample(long cardinality) { - this.cardinality = cardinality; - } + return sample; } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java similarity index 78% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java index 0bbc9d75d8b..fee589df371 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java @@ -15,12 +15,10 @@ import java.util.List; import java.util.stream.IntStream; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.Matchers.both; -import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.lessThan; -public class RandomCrossValidationSplitterTests extends ESTestCase { +public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase { private List fields; private int dependentVariableIndex; @@ -42,7 +40,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter crossValidationSplitter = createSplitter(50.0); + CrossValidationSplitter crossValidationSplitter = createSplitter(50.0, 0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -62,7 +60,7 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter crossValidationSplitter = createSplitter(100.0); + CrossValidationSplitter crossValidationSplitter = createSplitter(100.0, 100L); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -83,14 +81,14 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CrossValidationSplitter crossValidationSplitter = 
createSplitter(trainingPercent); + long rowCount = 1000; int runCount = 20; - int rowsCount = 1000; int[] trainingRowsPerRun = new int[runCount]; for (int testIndex = 0; testIndex < runCount; testIndex++) { + CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent, rowCount); int trainingRows = 0; - for (int i = 0; i < rowsCount; i++) { + for (int i = 0; i < rowCount; i++) { String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { row[fieldIndex] = randomAlphaOfLength(10); @@ -113,23 +111,11 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { } double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble(); - - // Now we need to calculate sensible bounds to assert against. - // We'll use 5 variances which should mean the test only fails once in 7M - // And, because we're doing multiple runs, we'll divide the variance with the number of runs to narrow the bounds - double expectedTrainingRows = trainingFraction * rowsCount; - double variance = rowsCount * (Math.pow(1 - trainingFraction, 2) * trainingFraction - + Math.pow(trainingFraction, 2) * (1 - trainingFraction)); - double lowerBound = expectedTrainingRows - 5 * Math.sqrt(variance / runCount); - double upperBound = expectedTrainingRows + 5 * Math.sqrt(variance / runCount); - - assertThat("Mean training rows [" + meanTrainingRows + "] was not within expected bounds of [" + lowerBound + ", " - + upperBound + "] given training fraction was [" + trainingFraction + "]", - meanTrainingRows, is(both(greaterThan(lowerBound)).and(lessThan(upperBound)))); + assertThat(meanTrainingRows, closeTo(trainingFraction * rowCount, 1.0)); } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CrossValidationSplitter crossValidationSplitter = createSplitter(1.0); + CrossValidationSplitter crossValidationSplitter = createSplitter(1.0, 1); // We have some non-training rows and then a training row to check // 
we maintain the first training row and not just the first row @@ -152,8 +138,8 @@ public class RandomCrossValidationSplitterTests extends ESTestCase { assertThat(testDocsCount, equalTo(9L)); } - private RandomCrossValidationSplitter createSplitter(double trainingPercent) { - return new RandomCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed); + private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) { + return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount); } private void incrementTrainingDocsCount() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java index 3ef7a13ff2d..57700803d0a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java @@ -32,7 +32,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { private int dependentVariableIndex; private String dependentVariable; private long randomizeSeed; - private Map classCardinalities; + private Map classCounts; private String[] classValuesPerRow; private long trainingDocsCount; private long testDocsCount; @@ -68,10 +68,10 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } } - classCardinalities = new HashMap<>(); - classCardinalities.put("a", classA); - classCardinalities.put("b", classB); - classCardinalities.put("c", classC); + classCounts = new HashMap<>(); + classCounts.put("a", classA); + classCounts.put("b", classB); + classCounts.put("c", classC); } public void 
testConstructor_GivenMissingDependentVariable() { @@ -143,7 +143,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { Map totalRowsPerClass = new HashMap<>(); Map trainingRowsPerClass = new HashMap<>(); - for (String classValue : classCardinalities.keySet()) { + for (String classValue : classCounts.keySet()) { totalRowsPerClass.put(classValue, 0); trainingRowsPerClass.put(classValue, 0); } @@ -178,14 +178,14 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { // We can assert we're plus/minus 1 from rounding error long expectedTotalTrainingCount = 0; - for (long classCardinality : classCardinalities.values()) { - expectedTotalTrainingCount += trainingFraction * classCardinality; + for (long classCount : classCounts.values()) { + expectedTotalTrainingCount += trainingFraction * classCount; } assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT)); assertThat(trainingDocsCount, greaterThanOrEqualTo(expectedTotalTrainingCount - 2)); assertThat(trainingDocsCount, lessThanOrEqualTo(expectedTotalTrainingCount)); - for (String classValue : classCardinalities.keySet()) { + for (String classValue : classCounts.keySet()) { double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction; int classTrainingCount = trainingRowsPerClass.get(classValue); assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0))); @@ -228,12 +228,12 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } } - public void testProcess_GivenTwoClassesWithCardinalityEqualToOne_ShouldUseForTraining() { + public void testProcess_GivenTwoClassesWithCountEqualToOne_ShouldUseForTraining() { dependentVariable = "dep_var"; fields = Arrays.asList(dependentVariable, "feature"); - classCardinalities = new HashMap<>(); - classCardinalities.put("class_a", 1L); - classCardinalities.put("class_b", 1L); + classCounts = new HashMap<>(); + classCounts.put("class_a", 1L); + 
classCounts.put("class_b", 1L); CrossValidationSplitter splitter = createSplitter(80.0); { @@ -258,7 +258,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } private CrossValidationSplitter createSplitter(double trainingPercent) { - return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed); + return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed); } private void incrementTrainingDocsCount() {