[7.x][ML] Stratified cross validation split for classification (#54087) (#54104)

As classification now works for multiple classes, randomly
picking training/test data frame rows is not good enough.
This commit introduces a stratified cross validation splitter
that maintains the proportion of the each class in the dataset
in the sample that is used for training the model.

Backport of #54087
This commit is contained in:
Dimitris Athanasiou 2020-03-24 18:47:36 +02:00 committed by GitHub
parent e006d1f6cf
commit c141c1dd89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 403 additions and 24 deletions

View File

@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
/** /**
* The max number of classes classification supports * The max number of classes classification supports
*/ */
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30; public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) { private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>( ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(

View File

@ -25,7 +25,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MlStatsIndex; import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts; import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@ -162,7 +161,7 @@ public class AnalyticsProcessManager {
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get(); AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
try { try {
writeHeaderRecord(dataExtractor, process); writeHeaderRecord(dataExtractor, process);
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(), writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
task.getStatsHolder().getDataCountsTracker()); task.getStatsHolder().getDataCountsTracker());
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()), processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
DataCounts::documentId); DataCounts::documentId);
@ -214,11 +213,12 @@ public class AnalyticsProcessManager {
} }
} }
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis, private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException { DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
throws IOException {
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames()) CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
.create(analysis); .create();
// The extra fields are for the doc hash and the control field (should be an empty string) // The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2]; String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@ -324,7 +324,8 @@ public class AnalyticsProcessManager {
); );
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices())); LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
jobId, Arrays.toString(refreshRequest.indices())));
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
client.admin().indices().refresh(refreshRequest).actionGet(); client.admin().indices().refresh(refreshRequest).actionGet();

View File

@ -5,32 +5,81 @@
*/ */
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
public class CrossValidationSplitterFactory { public class CrossValidationSplitterFactory {
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
private final Client client;
private final DataFrameAnalyticsConfig config;
private final List<String> fieldNames; private final List<String> fieldNames;
public CrossValidationSplitterFactory(List<String> fieldNames) { public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
this.client = Objects.requireNonNull(client);
this.config = Objects.requireNonNull(config);
this.fieldNames = Objects.requireNonNull(fieldNames); this.fieldNames = Objects.requireNonNull(fieldNames);
} }
public CrossValidationSplitter create(DataFrameAnalysis analysis) { public CrossValidationSplitter create() {
if (analysis instanceof Regression) { if (config.getAnalysis() instanceof Regression) {
Regression regression = (Regression) analysis; return createRandomSplitter();
return new RandomCrossValidationSplitter(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
} }
if (analysis instanceof Classification) { if (config.getAnalysis() instanceof Classification) {
Classification classification = (Classification) analysis; return createStratifiedSplitter((Classification) config.getAnalysis());
return new RandomCrossValidationSplitter(
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
} }
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run(); return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
} }
private CrossValidationSplitter createRandomSplitter() {
Regression regression = (Regression) config.getAnalysis();
return new RandomCrossValidationSplitter(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
}
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
String aggName = "dependent_variable_terms";
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
.setAllowPartialSearchResults(false)
.addAggregation(AggregationBuilders.terms(aggName)
.field(classification.getDependentVariable())
.size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));
try {
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
searchRequestBuilder::get);
Aggregations aggs = searchResponse.getAggregations();
Terms terms = aggs.get(aggName);
Map<String, Long> classCardinalities = new HashMap<>();
for (Terms.Bucket bucket : terms.getBuckets()) {
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
}
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
classification.getTrainingPercent(), classification.getRandomizeSeed());
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
LOGGER.error(msg, e);
throw new ElasticsearchException(msg.getFormattedMessage(), e);
}
}
} }

View File

@ -25,18 +25,18 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
private boolean isFirstRow = true; private boolean isFirstRow = true;
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.trainingPercent = trainingPercent; this.trainingPercent = trainingPercent;
this.random = new Random(randomizeSeed); this.random = new Random(randomizeSeed);
} }
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) { private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
for (int i = 0; i < fieldNames.size(); i++) { int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
if (fieldNames.get(i).equals(dependentVariable)) { if (dependentVariableIndex < 0) {
return i; throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
}
} }
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames); return dependentVariableIndex;
} }
@Override @Override

View File

@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* Given a dependent variable, randomly splits the dataset trying
* to preserve the proportion of each class in the training sample.
*/
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {
private final int dependentVariableIndex;
private final double samplingRatio;
private final Random random;
private final Map<String, ClassSample> classSamples;
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
double trainingPercent, long randomizeSeed) {
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.samplingRatio = trainingPercent / 100.0;
this.random = new Random(randomizeSeed);
this.classSamples = new HashMap<>();
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
}
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
if (dependentVariableIndex < 0) {
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
}
return dependentVariableIndex;
}
@Override
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
if (canBeUsedForTraining(row) == false) {
incrementTestDocs.run();
return;
}
String classValue = row[dependentVariableIndex];
ClassSample sample = classSamples.get(classValue);
if (sample == null) {
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
}
// The idea here is that the probability increases as the chances we have to get the target proportion
// for a class decreases.
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
boolean isTraining = random.nextDouble() <= p;
sample.observed++;
if (isTraining) {
sample.training++;
incrementTrainingDocs.run();
} else {
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
incrementTestDocs.run();
}
}
private boolean canBeUsedForTraining(String[] row) {
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
}
private static class ClassSample {
private final long cardinality;
private long training;
private long observed;
private ClassSample(long cardinality) {
this.cardinality = cardinality;
}
}
}

View File

@ -0,0 +1,239 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class StratifiedCrossValidationSplitterTests extends ESTestCase {
private static final int ROWS_COUNT = 500;
private List<String> fields;
private int dependentVariableIndex;
private String dependentVariable;
private long randomizeSeed;
private Map<String, Long> classCardinalities;
private String[] classValuesPerRow;
private long trainingDocsCount;
private long testDocsCount;
@Before
public void setUpTests() {
int fieldCount = randomIntBetween(1, 5);
fields = new ArrayList<>(fieldCount);
for (int i = 0; i < fieldCount; i++) {
fields.add(randomAlphaOfLength(10));
}
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
dependentVariable = fields.get(dependentVariableIndex);
randomizeSeed = randomLong();
long classA = 0;
long classB = 0;
long classC = 0;
classValuesPerRow = new String[ROWS_COUNT];
for (int i = 0; i < classValuesPerRow.length; i++) {
double randomDouble = randomDoubleBetween(0.0, 1.0, true);
if (randomDouble < 0.2) {
classValuesPerRow[i] = "a";
classA++;
} else if (randomDouble < 0.5) {
classValuesPerRow[i] = "b";
classB++;
} else {
classValuesPerRow[i] = "c";
classC++;
}
}
classCardinalities = new HashMap<>();
classCardinalities.put("a", classA);
classCardinalities.put("b", classB);
classCardinalities.put("c", classC);
}
public void testConstructor_GivenMissingDependentVariable() {
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter(
Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0));
assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []"));
}
public void testProcess_GivenUnknownClass() {
CrossValidationSplitter splitter = createSplitter(100.0);
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
row[fieldIndex] = randomAlphaOfLength(5);
}
row[dependentVariableIndex] = "unknown_class";
IllegalStateException e = expectThrows(IllegalStateException.class,
() -> splitter.process(row, this::incrementTrainingDocsCount, this::incrementTestDocsCount));
assertThat(e.getMessage(), equalTo("Unknown class [unknown_class]; expected one of [a, b, c]"));
}
public void testProcess_GivenRowsWithoutDependentVariableValue() {
CrossValidationSplitter splitter = createSplitter(50.0);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? DataFrameDataExtractor.NULL_VALUE : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
// As all these rows have no dependent variable value, they're not for training and should be unaffected
assertThat(Arrays.equals(processedRow, row), is(true));
}
assertThat(trainingDocsCount, equalTo(0L));
assertThat(testDocsCount, equalTo(500L));
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CrossValidationSplitter splitter = createSplitter(100.0);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
// As training percent is 100 all rows should be unaffected
assertThat(Arrays.equals(processedRow, row), is(true));
}
assertThat(trainingDocsCount, equalTo(500L));
assertThat(testDocsCount, equalTo(0L));
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
// We don't go too low here to avoid flakiness
double trainingPercent = randomDoubleBetween(50.0, 100.0, true);
CrossValidationSplitter splitter = createSplitter(trainingPercent);
Map<String, Integer> totalRowsPerClass = new HashMap<>();
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
for (String classValue : classCardinalities.keySet()) {
totalRowsPerClass.put(classValue, 0);
trainingRowsPerClass.put(classValue, 0);
}
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
splitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
if (fieldIndex != dependentVariableIndex) {
assertThat(processedRow[fieldIndex], equalTo(row[fieldIndex]));
}
}
String classValue = row[dependentVariableIndex];
totalRowsPerClass.compute(classValue, (k, v) -> v + 1);
if (DataFrameDataExtractor.NULL_VALUE.equals(processedRow[dependentVariableIndex]) == false) {
assertThat(processedRow[dependentVariableIndex], equalTo(row[dependentVariableIndex]));
trainingRowsPerClass.compute(classValue, (k, v) -> v + 1);
}
}
double trainingFraction = trainingPercent / 100;
// We can assert we're plus/minus 1 from rounding error
double expectedTotalTrainingCount = ROWS_COUNT * trainingFraction;
assertThat(trainingDocsCount + testDocsCount, equalTo((long) ROWS_COUNT));
assertThat(trainingDocsCount, greaterThanOrEqualTo((long) Math.floor(expectedTotalTrainingCount - 1)));
assertThat(trainingDocsCount, lessThanOrEqualTo((long) Math.ceil(expectedTotalTrainingCount + 1)));
for (String classValue : classCardinalities.keySet()) {
double expectedClassTrainingCount = totalRowsPerClass.get(classValue) * trainingFraction;
int classTrainingCount = trainingRowsPerClass.get(classValue);
assertThat((double) classTrainingCount, is(closeTo(expectedClassTrainingCount, 1.0)));
}
}
public void testProcess_SelectsTrainingRowsUniformly() {
double trainingPercent = 50.0;
int runCount = 500;
int[] trainingCountPerRow = new int[ROWS_COUNT];
for (int run = 0; run < runCount; run++) {
randomizeSeed = randomLong();
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? classValuesPerRow[i] : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
crossValidationSplitter.process(processedRow, this::incrementTrainingDocsCount, this::incrementTestDocsCount);
if (processedRow[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE) {
trainingCountPerRow[i]++;
}
}
}
// We expect each data row to be selected uniformly.
// Thus the fraction of the row count where it's selected for training against the number of runs
// should be close to the training percent, which is set to 0.5
for (int rowTrainingCount : trainingCountPerRow) {
double meanCount = rowTrainingCount / (double) runCount;
assertThat(meanCount, is(closeTo(0.5, 0.1)));
}
}
private CrossValidationSplitter createSplitter(double trainingPercent) {
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCardinalities, trainingPercent, randomizeSeed);
}
private void incrementTrainingDocsCount() {
trainingDocsCount++;
}
private void incrementTestDocsCount() {
testDocsCount++;
}
}