As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains the proportion of the each class in the dataset in the sample that is used for training the model. Backport of #54087
This commit is contained in:
parent
e006d1f6cf
commit
c141c1dd89
|
@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
|
|||
/**
|
||||
* The max number of classes classification supports
|
||||
*/
|
||||
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
|
||||
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
|
||||
|
||||
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
|
||||
|
|
|
@ -25,7 +25,6 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
|
||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||
|
@ -162,7 +161,7 @@ public class AnalyticsProcessManager {
|
|||
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
|
||||
try {
|
||||
writeHeaderRecord(dataExtractor, process);
|
||||
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
|
||||
writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
|
||||
task.getStatsHolder().getDataCountsTracker());
|
||||
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
|
||||
DataCounts::documentId);
|
||||
|
@ -214,11 +213,12 @@ public class AnalyticsProcessManager {
|
|||
}
|
||||
}
|
||||
|
||||
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
|
||||
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {
|
||||
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
|
||||
DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
|
||||
throws IOException {
|
||||
|
||||
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
|
||||
.create(analysis);
|
||||
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
|
||||
.create();
|
||||
|
||||
// The extra fields are for the doc hash and the control field (should be an empty string)
|
||||
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
|
||||
|
@ -324,7 +324,8 @@ public class AnalyticsProcessManager {
|
|||
);
|
||||
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
|
||||
|
||||
LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices()));
|
||||
LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
|
||||
jobId, Arrays.toString(refreshRequest.indices())));
|
||||
|
||||
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
|
||||
client.admin().indices().refresh(refreshRequest).actionGet();
|
||||
|
|
|
@ -5,32 +5,81 @@
|
|||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||
import org.elasticsearch.search.aggregations.Aggregations;
|
||||
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class CrossValidationSplitterFactory {
|
||||
|
||||
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
|
||||
|
||||
private final Client client;
|
||||
private final DataFrameAnalyticsConfig config;
|
||||
private final List<String> fieldNames;
|
||||
|
||||
public CrossValidationSplitterFactory(List<String> fieldNames) {
|
||||
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.config = Objects.requireNonNull(config);
|
||||
this.fieldNames = Objects.requireNonNull(fieldNames);
|
||||
}
|
||||
|
||||
public CrossValidationSplitter create(DataFrameAnalysis analysis) {
|
||||
if (analysis instanceof Regression) {
|
||||
Regression regression = (Regression) analysis;
|
||||
return new RandomCrossValidationSplitter(
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
||||
public CrossValidationSplitter create() {
|
||||
if (config.getAnalysis() instanceof Regression) {
|
||||
return createRandomSplitter();
|
||||
}
|
||||
if (analysis instanceof Classification) {
|
||||
Classification classification = (Classification) analysis;
|
||||
return new RandomCrossValidationSplitter(
|
||||
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||
if (config.getAnalysis() instanceof Classification) {
|
||||
return createStratifiedSplitter((Classification) config.getAnalysis());
|
||||
}
|
||||
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createRandomSplitter() {
|
||||
Regression regression = (Regression) config.getAnalysis();
|
||||
return new RandomCrossValidationSplitter(
|
||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
|
||||
String aggName = "dependent_variable_terms";
|
||||
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
|
||||
.setSize(0)
|
||||
.setAllowPartialSearchResults(false)
|
||||
.addAggregation(AggregationBuilders.terms(aggName)
|
||||
.field(classification.getDependentVariable())
|
||||
.size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));
|
||||
|
||||
try {
|
||||
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
|
||||
searchRequestBuilder::get);
|
||||
Aggregations aggs = searchResponse.getAggregations();
|
||||
Terms terms = aggs.get(aggName);
|
||||
Map<String, Long> classCardinalities = new HashMap<>();
|
||||
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
|
||||
}
|
||||
|
||||
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
|
||||
classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||
} catch (Exception e) {
|
||||
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
|
||||
LOGGER.error(msg, e);
|
||||
throw new ElasticsearchException(msg.getFormattedMessage(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,19 +25,19 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
|
|||
private boolean isFirstRow = true;
|
||||
|
||||
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
|
||||
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.trainingPercent = trainingPercent;
|
||||
this.random = new Random(randomizeSeed);
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
for (int i = 0; i < fieldNames.size(); i++) {
|
||||
if (fieldNames.get(i).equals(dependentVariable)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||
if (dependentVariableIndex < 0) {
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
return dependentVariableIndex;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* Given a dependent variable, randomly splits the dataset trying
|
||||
* to preserve the proportion of each class in the training sample.
|
||||
*/
|
||||
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {
|
||||
|
||||
private final int dependentVariableIndex;
|
||||
private final double samplingRatio;
|
||||
private final Random random;
|
||||
private final Map<String, ClassSample> classSamples;
|
||||
|
||||
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
|
||||
double trainingPercent, long randomizeSeed) {
|
||||
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.samplingRatio = trainingPercent / 100.0;
|
||||
this.random = new Random(randomizeSeed);
|
||||
this.classSamples = new HashMap<>();
|
||||
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||
if (dependentVariableIndex < 0) {
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
return dependentVariableIndex;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||
|
||||
if (canBeUsedForTraining(row) == false) {
|
||||
incrementTestDocs.run();
|
||||
return;
|
||||
}
|
||||
|
||||
String classValue = row[dependentVariableIndex];
|
||||
ClassSample sample = classSamples.get(classValue);
|
||||
if (sample == null) {
|
||||
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
|
||||
}
|
||||
|
||||
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||
// for a class decreases.
|
||||
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
|
||||
|
||||
boolean isTraining = random.nextDouble() <= p;
|
||||
|
||||
sample.observed++;
|
||||
if (isTraining) {
|
||||
sample.training++;
|
||||
incrementTrainingDocs.run();
|
||||
} else {
|
||||
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
|
||||
incrementTestDocs.run();
|
||||
}
|
||||
}
|
||||
|
||||
private boolean canBeUsedForTraining(String[] row) {
|
||||
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
|
||||
}
|
||||
|
||||
private static class ClassSample {
|
||||
|
||||
private final long cardinality;
|
||||
private long training;
|
||||
private long observed;
|
||||
|
||||
private ClassSample(long cardinality) {
|
||||
this.cardinality = cardinality;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,239 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
|
||||
public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
||||
|
||||
private static final int ROWS_COUNT = 500;
|
||||
|
||||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
private String dependentVariable;
|
||||
private long randomizeSeed;
|
||||
private Map<String, Long> classCardinalities;
|
||||
private String[] classValuesPerRow;
|
||||
private long trainingDocsCount;
|
||||
private long testDocsCount;
|
||||
|
||||
@Before
|
||||
public void setUpTests() {
|
||||
int fieldCount = randomIntBetween(1, 5);
|
||||
fields = new ArrayList<>(fieldCount);
|
||||
for (int i = 0; i < fieldCount; i++) {
|
||||
fields.add(randomAlphaOfLength(10));
|
||||
}
|
||||
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
|
||||
dependentVariable = fields.get(dependentVariableIndex);
|
||||
randomizeSeed = randomLong();
|
||||
|
||||
long classA = 0;
|
||||
long classB = 0;
|
||||
long classC = 0;
|
||||
|
||||
|
||||
classValuesPerRow = new String[ROWS_COUNT];
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
double randomDouble = randomDoubleBetween(0.0, 1.0, true);
|
||||
if (randomDouble < 0.2) {
|
||||
classValuesPerRow[i] = "a";
|
||||
classA++;
|
||||
} else if (randomDouble < 0.5) {
|
||||
classValuesPerRow[i] = "b";
|
||||
classB++;
|
||||
} else {
|
||||
classValuesPerRow[i] = "c";
|
||||
classC++;
|
||||
}
|
||||
}
|
||||
|
||||
classCardinalities = new HashMap<>();
|
||||
classCardinalities.put("a", classA);
|
||||
classCardinalities.put("b", classB);
|
||||
classCardinalities.put("c", classC);
|
||||
}
|
||||
|
||||
public void testConstructor_GivenMissingDependentVariable() {
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter(
|
||||
Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0));
|
||||
assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []"));
|
||||
}
|
||||
|
||||
public void testProcess_GivenUnknownClass() {
|
||||
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
row[fieldIndex] = randomAlphaOfLength(5);
|
||||
}
|
||||
row[dependentVariableIndex] = "unknown_class";
|
||||
|
||||
IllegalStateException e = expectThrows(IllegalStateException.class,
|
||||
() -> splitter.process(row, this::incrementTrainingDocsCount, this::incrementTestDocsCount));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("Unknown class [unknown_class]; expected one of [a, b, c]"));
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||
CrossValidationSplitter splitter = createSplitter(50.0);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
// As all these rows have no dependent variable value, they're not for training and should be unaffected
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
assertThat(trainingDocsCount, equalTo(0L));
|
||||
assertThat(testDocsCount, equalTo(500L));
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
// As training percent is 100 all rows should be unaffected
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
assertThat(trainingDocsCount, equalTo(500L));
|
||||
assertThat(testDocsCount, equalTo(0L));
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||
// We don't go too low here to avoid flakiness
|
||||
double trainingPercent = randomDoubleBetween(50.0, 100.0, true);
|
||||
|
||||
CrossValidationSplitter splitter = createSplitter(trainingPercent);
|
||||
|
||||
Map<String, Integer> totalRowsPerClass = new HashMap<>();
|
||||
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
|
||||
|
||||
for (String classValue : classCardinalities.keySet()) {
|
||||
totalRowsPerClass.put(classValue, 0);
|
||||
trainingRowsPerClass.put(classValue, 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
if (fieldIndex != dependentVariableIndex) {
|
||||
assertThat(processedRow[fieldIndex], equalTo(row[fieldIndex]));
|
||||
}
|
||||
}
|
||||
|
||||
String classValue = row[dependentVariableIndex];
|
||||
totalRowsPerClass.compute(classValue, (k, v) -> v + 1);
|
||||
|
||||
if (DataFrameDataExtractor.NULL_VALUE.equals(processedRow[dependentVariableIndex]) == false) {
|
||||
assertThat(processedRow[dependentVariableIndex], equalTo(row[dependentVariableIndex]));
|
||||
trainingRowsPerClass.compute(classValue, (k, v) -> v + 1);
|
||||
}
|
||||
}
|
||||
|
||||
double trainingFraction = trainingPercent / 100;
|
||||
|
||||
// We can assert we're plus/minus 1 from rounding error
|
||||
|
||||
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
|
||||
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
|
||||
assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1)));
|
||||
assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1)));
|
||||
|
||||
for (String classValue : classCardinalities.keySet()) {
|
||||
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
|
||||
int classTrainingCount = trainingRowsPerClass.get(classValue);
|
||||
assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0)));
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_SelectsTrainingRowsUniformly() {
|
||||
double trainingPercent = 50.0;
|
||||
int runCount = 500;
|
||||
|
||||
int[] trainingCountPerRow = new int[ROWS_COUNT];
|
||||
|
||||
for (int run = 0; run < runCount; run++) {
|
||||
|
||||
randomizeSeed = randomLong();
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||
|
||||
if (processedRow[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE) {
|
||||
trainingCountPerRow[i]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We expect each data row to be selected uniformly.
|
||||
// Thus the fraction of the row count where it's selected for training against the number of runs
|
||||
// should be close to the training percent, which is set to 0.5
|
||||
for (int rowTrainingCount : trainingCountPerRow) {
|
||||
double meanCount = rowTrainingCount / (double) runCount;
|
||||
assertThat(meanCount, is(closeTo(0.5, 0.1)));
|
||||
}
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent) {
|
||||
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed);
|
||||
}
|
||||
|
||||
private void incrementTrainingDocsCount() {
|
||||
trainingDocsCount++;
|
||||
}
|
||||
|
||||
private void incrementTestDocsCount() {
|
||||
testDocsCount++;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue