[7.x][ML] Rename cross validation splitter package (#59529) (#59544)

Renames and moves the cross validation splitter package.

First, the package and classes are renamed from
"cross validation splitter" to "train test splitter".
Cross validation as a term is overloaded and encompasses
more concepts than what we are trying to do here.

Second, the package used to be under `process`, but it does
not make sense for it to be there; it can be a top-level package
under `dataframe`.

Backport of #59529
This commit is contained in:
Dimitris Athanasiou 2020-07-14 18:54:46 +03:00 committed by GitHub
parent 37406487b9
commit ee4610c0ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 66 additions and 66 deletions

View File

@ -29,7 +29,7 @@ import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@ -66,14 +66,14 @@ public class DataFrameDataExtractor {
private boolean isCancelled;
private boolean hasNext;
private boolean searchHasShardFailure;
private final CachedSupplier<CrossValidationSplitter> crossValidationSplitter;
private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
this.client = Objects.requireNonNull(client);
this.context = Objects.requireNonNull(context);
hasNext = true;
searchHasShardFailure = false;
this.crossValidationSplitter = new CachedSupplier<>(context.crossValidationSplitterFactory::create);
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
}
public Map<String, String> getHeaders() {
@ -207,7 +207,7 @@ public class DataFrameDataExtractor {
}
}
}
boolean isTraining = extractedValues == null ? false : crossValidationSplitter.get().isTraining(extractedValues);
boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues);
return new Row(extractedValues, hit, isTraining);
}

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.ml.dataframe.extractor;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import java.util.List;
@ -23,11 +23,11 @@ public class DataFrameDataExtractorContext {
final Map<String, String> headers;
final boolean includeSource;
final boolean supportsRowsWithMissingValues;
final CrossValidationSplitterFactory crossValidationSplitterFactory;
final TrainTestSplitterFactory trainTestSplitterFactory;
DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, QueryBuilder query, int scrollSize,
Map<String, String> headers, boolean includeSource, boolean supportsRowsWithMissingValues,
CrossValidationSplitterFactory crossValidationSplitterFactory) {
TrainTestSplitterFactory trainTestSplitterFactory) {
this.jobId = Objects.requireNonNull(jobId);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.indices = indices.toArray(new String[indices.size()]);
@ -36,6 +36,6 @@ public class DataFrameDataExtractorContext {
this.headers = headers;
this.includeSource = includeSource;
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
}
}

View File

@ -12,7 +12,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@ -33,12 +33,12 @@ public class DataFrameDataExtractorFactory {
private final List<RequiredField> requiredFields;
private final Map<String, String> headers;
private final boolean supportsRowsWithMissingValues;
private final CrossValidationSplitterFactory crossValidationSplitterFactory;
private final TrainTestSplitterFactory trainTestSplitterFactory;
private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
ExtractedFields extractedFields, List<RequiredField> requiredFields, Map<String, String> headers,
boolean supportsRowsWithMissingValues,
CrossValidationSplitterFactory crossValidationSplitterFactory) {
TrainTestSplitterFactory trainTestSplitterFactory) {
this.client = Objects.requireNonNull(client);
this.analyticsId = Objects.requireNonNull(analyticsId);
this.indices = Objects.requireNonNull(indices);
@ -47,7 +47,7 @@ public class DataFrameDataExtractorFactory {
this.requiredFields = Objects.requireNonNull(requiredFields);
this.headers = headers;
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
}
public DataFrameDataExtractor newExtractor(boolean includeSource) {
@ -60,7 +60,7 @@ public class DataFrameDataExtractorFactory {
headers,
includeSource,
supportsRowsWithMissingValues,
crossValidationSplitterFactory
trainTestSplitterFactory
);
return new DataFrameDataExtractor(client, context);
}
@ -89,12 +89,12 @@ public class DataFrameDataExtractorFactory {
ExtractedFields extractedFields) {
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()),
config.getSource().getParsedQuery(), extractedFields, config.getAnalysis().getRequiredFields(), config.getHeaders(),
config.getAnalysis().supportsMissingValues(), createCrossValidationSplitterFactory(client, config, extractedFields));
config.getAnalysis().supportsMissingValues(), createTrainTestSplitterFactory(client, config, extractedFields));
}
private static CrossValidationSplitterFactory createCrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new CrossValidationSplitterFactory(client, config,
private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new TrainTestSplitterFactory(client, config,
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
}
@ -118,7 +118,7 @@ public class DataFrameDataExtractorFactory {
DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
config.getAnalysis().getRequiredFields(), config.getHeaders(), config.getAnalysis().supportsMissingValues(),
createCrossValidationSplitterFactory(client, config, extractedFields));
createTrainTestSplitterFactory(client, config, extractedFields));
listener.onResponse(extractorFactory);
},
listener::onFailure

View File

@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@ -17,14 +17,14 @@ import java.util.Random;
* is based on the reservoir idea. It randomly picks training docs while
* respecting the exact training percent.
*/
abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter {
abstract class AbstractReservoirTrainTestSplitter implements TrainTestSplitter {
protected final int dependentVariableIndex;
private final double samplingRatio;
private final Random random;
AbstractReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed) {
AbstractReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed) {
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.samplingRatio = trainingPercent / 100.0;

View File

@ -4,16 +4,16 @@
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import java.util.List;
public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
public class SingleClassReservoirTrainTestSplitter extends AbstractReservoirTrainTestSplitter {
private final SampleInfo sampleInfo;
SingleClassReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed, long classCount) {
SingleClassReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
long randomizeSeed, long classCount) {
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
sampleInfo = new SampleInfo(classCount);
}

View File

@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import java.util.HashMap;
import java.util.List;
@ -14,12 +14,12 @@ import java.util.Map;
* Given a dependent variable, randomly splits the dataset trying
* to preserve the proportion of each class in the training sample.
*/
public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter {
private final Map<String, SampleInfo> classSamples;
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
double trainingPercent, long randomizeSeed) {
public StratifiedTrainTestSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
double trainingPercent, long randomizeSeed) {
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
this.classSamples = new HashMap<>();
classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue())));

View File

@ -3,12 +3,12 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
/**
* Processes rows in order to split the dataset in training and test subsets
*/
public interface CrossValidationSplitter {
public interface TrainTestSplitter {
boolean isTraining(String[] row);
}

View File

@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -26,21 +26,21 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public class CrossValidationSplitterFactory {
public class TrainTestSplitterFactory {
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
private static final Logger LOGGER = LogManager.getLogger(TrainTestSplitterFactory.class);
private final Client client;
private final DataFrameAnalyticsConfig config;
private final List<String> fieldNames;
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
public TrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
this.client = Objects.requireNonNull(client);
this.config = Objects.requireNonNull(config);
this.fieldNames = Objects.requireNonNull(fieldNames);
}
public CrossValidationSplitter create() {
public TrainTestSplitter create() {
if (config.getAnalysis() instanceof Regression) {
return createSingleClassSplitter((Regression) config.getAnalysis());
}
@ -50,7 +50,7 @@ public class CrossValidationSplitterFactory {
return row -> true;
}
private CrossValidationSplitter createSingleClassSplitter(Regression regression) {
private TrainTestSplitter createSingleClassSplitter(Regression regression) {
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
.setAllowPartialSearchResults(false)
@ -60,7 +60,7 @@ public class CrossValidationSplitterFactory {
try {
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
searchRequestBuilder::get);
return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(),
return new SingleClassReservoirTrainTestSplitter(fieldNames, regression.getDependentVariable(),
regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value);
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId());
@ -69,7 +69,7 @@ public class CrossValidationSplitterFactory {
}
}
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
private TrainTestSplitter createStratifiedSplitter(Classification classification) {
String aggName = "dependent_variable_terms";
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
.setSize(0)
@ -88,7 +88,7 @@ public class CrossValidationSplitterFactory {
classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
}
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts,
return new StratifiedTrainTestSplitter(fieldNames, classification.getDependentVariable(), classCounts,
classification.getTrainingPercent(), classification.getRandomizeSeed());
} catch (Exception e) {
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());

View File

@ -27,7 +27,7 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
import org.elasticsearch.xpack.ml.extractor.DocValueField;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@ -67,7 +67,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
private QueryBuilder query;
private int scrollSize;
private Map<String, String> headers;
private CrossValidationSplitterFactory crossValidationSplitterFactory;
private TrainTestSplitterFactory trainTestSplitterFactory;
private ArgumentCaptor<ClearScrollRequest> capturedClearScrollRequests;
private ActionFuture<ClearScrollResponse> clearScrollFuture;
@ -87,8 +87,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
scrollSize = 1000;
headers = Collections.emptyMap();
crossValidationSplitterFactory = mock(CrossValidationSplitterFactory.class);
when(crossValidationSplitterFactory.create()).thenReturn(row -> true);
trainTestSplitterFactory = mock(TrainTestSplitterFactory.class);
when(trainTestSplitterFactory.create()).thenReturn(row -> true);
clearScrollFuture = mock(ActionFuture.class);
capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class);
@ -467,7 +467,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
headers, includeSource, supportsRowsWithMissingValues, crossValidationSplitterFactory);
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
return new TestExtractor(client, context);
}

View File

@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@ -17,7 +17,7 @@ import java.util.stream.IntStream;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.is;
public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase {
public class SingleClassReservoirTrainTestSplitterTests extends ESTestCase {
private List<String> fields;
private int dependentVariableIndex;
@ -37,7 +37,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
}
public void testIsTraining_GivenRowsWithoutDependentVariableValue() {
CrossValidationSplitter splitter = createSplitter(50.0, 0);
TrainTestSplitter splitter = createSplitter(50.0, 0);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -53,7 +53,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
}
public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CrossValidationSplitter splitter = createSplitter(100.0, 100L);
TrainTestSplitter splitter = createSplitter(100.0, 100L);
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
@ -75,7 +75,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
int runCount = 20;
int[] trainingRowsPerRun = new int[runCount];
for (int testIndex = 0; testIndex < runCount; testIndex++) {
CrossValidationSplitter splitter = createSplitter(trainingPercent, rowCount);
TrainTestSplitter splitter = createSplitter(trainingPercent, rowCount);
int trainingRows = 0;
for (int i = 0; i < rowCount; i++) {
String[] row = new String[fields.size()];
@ -99,7 +99,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
}
public void testIsTraining_ShouldHaveAtLeastOneTrainingRow() {
CrossValidationSplitter splitter = createSplitter(1.0, 1);
TrainTestSplitter splitter = createSplitter(1.0, 1);
// We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row
@ -121,7 +121,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
}
}
private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) {
return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
private TrainTestSplitter createSplitter(double trainingPercent, long classCount) {
return new SingleClassReservoirTrainTestSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
}
}

View File

@ -4,7 +4,7 @@
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.test.ESTestCase;
@ -24,7 +24,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
public class StratifiedCrossValidationSplitterTests extends ESTestCase {
public class StratifiedTrainTestSplitterTests extends ESTestCase {
private static final int ROWS_COUNT = 500;
@ -73,13 +73,13 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
public void testConstructor_GivenMissingDependentVariable() {
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter(
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedTrainTestSplitter(
Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0));
assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []"));
}
public void testIsTraining_GivenUnknownClass() {
CrossValidationSplitter splitter = createSplitter(100.0);
TrainTestSplitter splitter = createSplitter(100.0);
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
row[fieldIndex] = randomAlphaOfLength(5);
@ -93,7 +93,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
public void testIsTraining_GivenRowsWithoutDependentVariableValue() {
CrossValidationSplitter splitter = createSplitter(50.0);
TrainTestSplitter splitter = createSplitter(50.0);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
@ -109,7 +109,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CrossValidationSplitter splitter = createSplitter(100.0);
TrainTestSplitter splitter = createSplitter(100.0);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
@ -128,7 +128,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
// We don't go too low here to avoid flakiness
double trainingPercent = randomDoubleBetween(50.0, 100.0, true);
CrossValidationSplitter splitter = createSplitter(trainingPercent);
TrainTestSplitter splitter = createSplitter(trainingPercent);
Map<String, Integer> totalRowsPerClass = new HashMap<>();
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
@ -189,7 +189,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
for (int run = 0; run < runCount; run++) {
randomizeSeed = randomLong();
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
TrainTestSplitter trainTestSplitter = createSplitter(trainingPercent);
for (int i = 0; i < classValuesPerRow.length; i++) {
String[] row = new String[fields.size()];
@ -199,7 +199,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
String[] processedRow = Arrays.copyOf(row, row.length);
boolean isTraining = crossValidationSplitter.isTraining(processedRow);
boolean isTraining = trainTestSplitter.isTraining(processedRow);
assertThat(Arrays.equals(processedRow, row), is(true));
if (isTraining) {
@ -223,7 +223,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
classCounts = new HashMap<>();
classCounts.put("class_a", 1L);
classCounts.put("class_b", 1L);
CrossValidationSplitter splitter = createSplitter(80.0);
TrainTestSplitter splitter = createSplitter(80.0);
{
String[] row = new String[]{"class_a", "42.0"};
@ -245,7 +245,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
}
}
private CrossValidationSplitter createSplitter(double trainingPercent) {
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
private TrainTestSplitter createSplitter(double trainingPercent) {
return new StratifiedTrainTestSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
}
}