As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains the proportion of the each class in the dataset in the sample that is used for training the model. Backport of #54087
This commit is contained in:
parent
e006d1f6cf
commit
c141c1dd89
|
@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
|
||||||
/**
|
/**
|
||||||
* The max number of classes classification supports
|
* The max number of classes classification supports
|
||||||
*/
|
*/
|
||||||
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
|
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
|
||||||
|
|
||||||
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
|
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
|
||||||
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
|
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
|
||||||
|
|
|
@ -25,7 +25,6 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.xpack.core.ClientHelper;
|
import org.elasticsearch.xpack.core.ClientHelper;
|
||||||
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
|
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
|
||||||
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
|
||||||
|
@ -162,7 +161,7 @@ public class AnalyticsProcessManager {
|
||||||
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
|
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
|
||||||
try {
|
try {
|
||||||
writeHeaderRecord(dataExtractor, process);
|
writeHeaderRecord(dataExtractor, process);
|
||||||
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
|
writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
|
||||||
task.getStatsHolder().getDataCountsTracker());
|
task.getStatsHolder().getDataCountsTracker());
|
||||||
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
|
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
|
||||||
DataCounts::documentId);
|
DataCounts::documentId);
|
||||||
|
@ -214,11 +213,12 @@ public class AnalyticsProcessManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
|
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
|
||||||
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {
|
DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
|
||||||
|
throws IOException {
|
||||||
|
|
||||||
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
|
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
|
||||||
.create(analysis);
|
.create();
|
||||||
|
|
||||||
// The extra fields are for the doc hash and the control field (should be an empty string)
|
// The extra fields are for the doc hash and the control field (should be an empty string)
|
||||||
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
|
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
|
||||||
|
@ -324,7 +324,8 @@ public class AnalyticsProcessManager {
|
||||||
);
|
);
|
||||||
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
|
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
|
||||||
|
|
||||||
LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices()));
|
LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
|
||||||
|
jobId, Arrays.toString(refreshRequest.indices())));
|
||||||
|
|
||||||
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
|
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
|
||||||
client.admin().indices().refresh(refreshRequest).actionGet();
|
client.admin().indices().refresh(refreshRequest).actionGet();
|
||||||
|
|
|
@ -5,32 +5,81 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||||
|
|
||||||
|
import org.apache.logging.log4j.LogManager;
|
||||||
|
import org.apache.logging.log4j.Logger;
|
||||||
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
|
import org.elasticsearch.action.search.SearchRequestBuilder;
|
||||||
|
import org.elasticsearch.action.search.SearchResponse;
|
||||||
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.search.aggregations.AggregationBuilders;
|
||||||
|
import org.elasticsearch.search.aggregations.Aggregations;
|
||||||
|
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
|
||||||
|
import org.elasticsearch.xpack.core.ClientHelper;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public class CrossValidationSplitterFactory {
|
public class CrossValidationSplitterFactory {
|
||||||
|
|
||||||
|
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
|
||||||
|
|
||||||
|
private final Client client;
|
||||||
|
private final DataFrameAnalyticsConfig config;
|
||||||
private final List<String> fieldNames;
|
private final List<String> fieldNames;
|
||||||
|
|
||||||
public CrossValidationSplitterFactory(List<String> fieldNames) {
|
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
|
||||||
|
this.client = Objects.requireNonNull(client);
|
||||||
|
this.config = Objects.requireNonNull(config);
|
||||||
this.fieldNames = Objects.requireNonNull(fieldNames);
|
this.fieldNames = Objects.requireNonNull(fieldNames);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CrossValidationSplitter create(DataFrameAnalysis analysis) {
|
public CrossValidationSplitter create() {
|
||||||
if (analysis instanceof Regression) {
|
if (config.getAnalysis() instanceof Regression) {
|
||||||
Regression regression = (Regression) analysis;
|
return createRandomSplitter();
|
||||||
return new RandomCrossValidationSplitter(
|
|
||||||
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
|
||||||
}
|
}
|
||||||
if (analysis instanceof Classification) {
|
if (config.getAnalysis() instanceof Classification) {
|
||||||
Classification classification = (Classification) analysis;
|
return createStratifiedSplitter((Classification) config.getAnalysis());
|
||||||
return new RandomCrossValidationSplitter(
|
|
||||||
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
|
|
||||||
}
|
}
|
||||||
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
|
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private CrossValidationSplitter createRandomSplitter() {
|
||||||
|
Regression regression = (Regression) config.getAnalysis();
|
||||||
|
return new RandomCrossValidationSplitter(
|
||||||
|
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
|
||||||
|
}
|
||||||
|
|
||||||
|
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
|
||||||
|
String aggName = "dependent_variable_terms";
|
||||||
|
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
|
||||||
|
.setSize(0)
|
||||||
|
.setAllowPartialSearchResults(false)
|
||||||
|
.addAggregation(AggregationBuilders.terms(aggName)
|
||||||
|
.field(classification.getDependentVariable())
|
||||||
|
.size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));
|
||||||
|
|
||||||
|
try {
|
||||||
|
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
|
||||||
|
searchRequestBuilder::get);
|
||||||
|
Aggregations aggs = searchResponse.getAggregations();
|
||||||
|
Terms terms = aggs.get(aggName);
|
||||||
|
Map<String, Long> classCardinalities = new HashMap<>();
|
||||||
|
for (Terms.Bucket bucket : terms.getBuckets()) {
|
||||||
|
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
|
||||||
|
classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||||
|
} catch (Exception e) {
|
||||||
|
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
|
||||||
|
LOGGER.error(msg, e);
|
||||||
|
throw new ElasticsearchException(msg.getFormattedMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,19 +25,19 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
|
||||||
private boolean isFirstRow = true;
|
private boolean isFirstRow = true;
|
||||||
|
|
||||||
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
|
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
|
||||||
|
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||||
this.trainingPercent = trainingPercent;
|
this.trainingPercent = trainingPercent;
|
||||||
this.random = new Random(randomizeSeed);
|
this.random = new Random(randomizeSeed);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||||
for (int i = 0; i < fieldNames.size(); i++) {
|
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||||
if (fieldNames.get(i).equals(dependentVariable)) {
|
if (dependentVariableIndex < 0) {
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||||
}
|
}
|
||||||
|
return dependentVariableIndex;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||||
|
|
||||||
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Given a dependent variable, randomly splits the dataset trying
|
||||||
|
* to preserve the proportion of each class in the training sample.
|
||||||
|
*/
|
||||||
|
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {
|
||||||
|
|
||||||
|
private final int dependentVariableIndex;
|
||||||
|
private final double samplingRatio;
|
||||||
|
private final Random random;
|
||||||
|
private final Map<String, ClassSample> classSamples;
|
||||||
|
|
||||||
|
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
|
||||||
|
double trainingPercent, long randomizeSeed) {
|
||||||
|
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||||
|
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||||
|
this.samplingRatio = trainingPercent / 100.0;
|
||||||
|
this.random = new Random(randomizeSeed);
|
||||||
|
this.classSamples = new HashMap<>();
|
||||||
|
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||||
|
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
|
||||||
|
if (dependentVariableIndex < 0) {
|
||||||
|
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||||
|
}
|
||||||
|
return dependentVariableIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
|
||||||
|
|
||||||
|
if (canBeUsedForTraining(row) == false) {
|
||||||
|
incrementTestDocs.run();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
String classValue = row[dependentVariableIndex];
|
||||||
|
ClassSample sample = classSamples.get(classValue);
|
||||||
|
if (sample == null) {
|
||||||
|
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
|
||||||
|
}
|
||||||
|
|
||||||
|
// The idea here is that the probability increases as the chances we have to get the target proportion
|
||||||
|
// for a class decreases.
|
||||||
|
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
|
||||||
|
|
||||||
|
boolean isTraining = random.nextDouble() <= p;
|
||||||
|
|
||||||
|
sample.observed++;
|
||||||
|
if (isTraining) {
|
||||||
|
sample.training++;
|
||||||
|
incrementTrainingDocs.run();
|
||||||
|
} else {
|
||||||
|
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
|
||||||
|
incrementTestDocs.run();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean canBeUsedForTraining(String[] row) {
|
||||||
|
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class ClassSample {
|
||||||
|
|
||||||
|
private final long cardinality;
|
||||||
|
private long training;
|
||||||
|
private long observed;
|
||||||
|
|
||||||
|
private ClassSample(long cardinality) {
|
||||||
|
this.cardinality = cardinality;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,239 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.hamcrest.CoreMatchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.closeTo;
|
||||||
|
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||||
|
|
||||||
|
public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
||||||
|
|
||||||
|
private static final int ROWS_COUNT = 500;
|
||||||
|
|
||||||
|
private List<String> fields;
|
||||||
|
private int dependentVariableIndex;
|
||||||
|
private String dependentVariable;
|
||||||
|
private long randomizeSeed;
|
||||||
|
private Map<String, Long> classCardinalities;
|
||||||
|
private String[] classValuesPerRow;
|
||||||
|
private long trainingDocsCount;
|
||||||
|
private long testDocsCount;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUpTests() {
|
||||||
|
int fieldCount = randomIntBetween(1, 5);
|
||||||
|
fields = new ArrayList<>(fieldCount);
|
||||||
|
for (int i = 0; i < fieldCount; i++) {
|
||||||
|
fields.add(randomAlphaOfLength(10));
|
||||||
|
}
|
||||||
|
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
|
||||||
|
dependentVariable = fields.get(dependentVariableIndex);
|
||||||
|
randomizeSeed = randomLong();
|
||||||
|
|
||||||
|
long classA = 0;
|
||||||
|
long classB = 0;
|
||||||
|
long classC = 0;
|
||||||
|
|
||||||
|
|
||||||
|
classValuesPerRow = new String[ROWS_COUNT];
|
||||||
|
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||||
|
double randomDouble = randomDoubleBetween(0.0, 1.0, true);
|
||||||
|
if (randomDouble < 0.2) {
|
||||||
|
classValuesPerRow[i] = "a";
|
||||||
|
classA++;
|
||||||
|
} else if (randomDouble < 0.5) {
|
||||||
|
classValuesPerRow[i] = "b";
|
||||||
|
classB++;
|
||||||
|
} else {
|
||||||
|
classValuesPerRow[i] = "c";
|
||||||
|
classC++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
classCardinalities = new HashMap<>();
|
||||||
|
classCardinalities.put("a", classA);
|
||||||
|
classCardinalities.put("b", classB);
|
||||||
|
classCardinalities.put("c", classC);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testConstructor_GivenMissingDependentVariable() {
|
||||||
|
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter(
|
||||||
|
Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0));
|
||||||
|
assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcess_GivenUnknownClass() {
|
||||||
|
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||||
|
String[] row = new String[fields.size()];
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
row[fieldIndex] = randomAlphaOfLength(5);
|
||||||
|
}
|
||||||
|
row[dependentVariableIndex] = "unknown_class";
|
||||||
|
|
||||||
|
IllegalStateException e = expectThrows(IllegalStateException.class,
|
||||||
|
() -> splitter.process(row, this::incrementTrainingDocsCount, this::incrementTestDocsCount));
|
||||||
|
|
||||||
|
assertThat(e.getMessage(), equalTo("Unknown class [unknown_class]; expected one of [a, b, c]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||||
|
CrossValidationSplitter splitter = createSplitter(50.0);
|
||||||
|
|
||||||
|
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||||
|
String[] row = new String[fields.size()];
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10);
|
||||||
|
row[fieldIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||||
|
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||||
|
|
||||||
|
// As all these rows have no dependent variable value, they're not for training and should be unaffected
|
||||||
|
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||||
|
}
|
||||||
|
assertThat(trainingDocsCount, equalTo(0L));
|
||||||
|
assertThat(testDocsCount, equalTo(500L));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||||
|
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||||
|
|
||||||
|
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||||
|
String[] row = new String[fields.size()];
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||||
|
row[fieldIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||||
|
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||||
|
|
||||||
|
// As training percent is 100 all rows should be unaffected
|
||||||
|
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||||
|
}
|
||||||
|
assertThat(trainingDocsCount, equalTo(500L));
|
||||||
|
assertThat(testDocsCount, equalTo(0L));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||||
|
// We don't go too low here to avoid flakiness
|
||||||
|
double trainingPercent = randomDoubleBetween(50.0, 100.0, true);
|
||||||
|
|
||||||
|
CrossValidationSplitter splitter = createSplitter(trainingPercent);
|
||||||
|
|
||||||
|
Map<String, Integer> totalRowsPerClass = new HashMap<>();
|
||||||
|
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
|
||||||
|
|
||||||
|
for (String classValue : classCardinalities.keySet()) {
|
||||||
|
totalRowsPerClass.put(classValue, 0);
|
||||||
|
trainingRowsPerClass.put(classValue, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||||
|
String[] row = new String[fields.size()];
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||||
|
row[fieldIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||||
|
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||||
|
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
if (fieldIndex != dependentVariableIndex) {
|
||||||
|
assertThat(processedRow[fieldIndex], equalTo(row[fieldIndex]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String classValue = row[dependentVariableIndex];
|
||||||
|
totalRowsPerClass.compute(classValue, (k, v) -> v + 1);
|
||||||
|
|
||||||
|
if (DataFrameDataExtractor.NULL_VALUE.equals(processedRow[dependentVariableIndex]) == false) {
|
||||||
|
assertThat(processedRow[dependentVariableIndex], equalTo(row[dependentVariableIndex]));
|
||||||
|
trainingRowsPerClass.compute(classValue, (k, v) -> v + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double trainingFraction = trainingPercent / 100;
|
||||||
|
|
||||||
|
// We can assert we're plus/minus 1 from rounding error
|
||||||
|
|
||||||
|
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
|
||||||
|
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
|
||||||
|
assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1)));
|
||||||
|
assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1)));
|
||||||
|
|
||||||
|
for (String classValue : classCardinalities.keySet()) {
|
||||||
|
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
|
||||||
|
int classTrainingCount = trainingRowsPerClass.get(classValue);
|
||||||
|
assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testProcess_SelectsTrainingRowsUniformly() {
|
||||||
|
double trainingPercent = 50.0;
|
||||||
|
int runCount = 500;
|
||||||
|
|
||||||
|
int[] trainingCountPerRow = new int[ROWS_COUNT];
|
||||||
|
|
||||||
|
for (int run = 0; run < runCount; run++) {
|
||||||
|
|
||||||
|
randomizeSeed = randomLong();
|
||||||
|
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
|
||||||
|
|
||||||
|
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||||
|
String[] row = new String[fields.size()];
|
||||||
|
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||||
|
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
|
||||||
|
row[fieldIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||||
|
crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
|
||||||
|
|
||||||
|
if (processedRow[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE) {
|
||||||
|
trainingCountPerRow[i]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We expect each data row to be selected uniformly.
|
||||||
|
// Thus the fraction of the row count where it's selected for training against the number of runs
|
||||||
|
// should be close to the training percent, which is set to 0.5
|
||||||
|
for (int rowTrainingCount : trainingCountPerRow) {
|
||||||
|
double meanCount = rowTrainingCount / (double) runCount;
|
||||||
|
assertThat(meanCount, is(closeTo(0.5, 0.1)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private CrossValidationSplitter createSplitter(double trainingPercent) {
|
||||||
|
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void incrementTrainingDocsCount() {
|
||||||
|
trainingDocsCount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void incrementTestDocsCount() {
|
||||||
|
testDocsCount++;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue