diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 9ccaf2ebb9d..a431cdbb965 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -29,7 +29,7 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.ml.dataframe.DestinationIndex; -import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter; +import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; @@ -66,14 +66,14 @@ public class DataFrameDataExtractor { private boolean isCancelled; private boolean hasNext; private boolean searchHasShardFailure; - private final CachedSupplier crossValidationSplitter; + private final CachedSupplier trainTestSplitter; DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) { this.client = Objects.requireNonNull(client); this.context = Objects.requireNonNull(context); hasNext = true; searchHasShardFailure = false; - this.crossValidationSplitter = new CachedSupplier<>(context.crossValidationSplitterFactory::create); + this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create); } public Map getHeaders() { @@ -207,7 +207,7 @@ public class DataFrameDataExtractor { } } } - boolean isTraining = extractedValues == null ? false : crossValidationSplitter.get().isTraining(extractedValues); + boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues); return new Row(extractedValues, hit, isTraining); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java index 55d2afc65d6..2ad9efea73d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.dataframe.extractor; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; +import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; import java.util.List; @@ -23,11 +23,11 @@ public class DataFrameDataExtractorContext { final Map headers; final boolean includeSource; final boolean supportsRowsWithMissingValues; - final CrossValidationSplitterFactory crossValidationSplitterFactory; + final TrainTestSplitterFactory trainTestSplitterFactory; DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, Map headers, boolean includeSource, boolean supportsRowsWithMissingValues, - CrossValidationSplitterFactory crossValidationSplitterFactory) { + TrainTestSplitterFactory trainTestSplitterFactory) { this.jobId = Objects.requireNonNull(jobId); this.extractedFields = Objects.requireNonNull(extractedFields); this.indices = indices.toArray(new String[indices.size()]); @@ -36,6 +36,6 @@ public class DataFrameDataExtractorContext { this.headers = headers; this.includeSource = includeSource; this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; - this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory); + this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 693c52647f3..353491aac8d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -12,7 +12,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; -import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; +import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; @@ -33,12 +33,12 @@ public class DataFrameDataExtractorFactory { private final List requiredFields; private final Map headers; private final boolean supportsRowsWithMissingValues; - private final CrossValidationSplitterFactory crossValidationSplitterFactory; + private final TrainTestSplitterFactory trainTestSplitterFactory; private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, ExtractedFields extractedFields, List requiredFields, Map headers, boolean supportsRowsWithMissingValues, - CrossValidationSplitterFactory crossValidationSplitterFactory) { + TrainTestSplitterFactory trainTestSplitterFactory) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); @@ -47,7 +47,7 @@ public class DataFrameDataExtractorFactory { this.requiredFields = Objects.requireNonNull(requiredFields); this.headers = headers; this.supportsRowsWithMissingValues = supportsRowsWithMissingValues; - this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory); + this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory); } public DataFrameDataExtractor newExtractor(boolean includeSource) { @@ -60,7 +60,7 @@ public class DataFrameDataExtractorFactory { headers, includeSource, supportsRowsWithMissingValues, - crossValidationSplitterFactory + trainTestSplitterFactory ); return new DataFrameDataExtractor(client, context); } @@ -89,12 +89,12 @@ public class DataFrameDataExtractorFactory { ExtractedFields extractedFields) { return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()), config.getSource().getParsedQuery(), extractedFields, config.getAnalysis().getRequiredFields(), config.getHeaders(), - config.getAnalysis().supportsMissingValues(), createCrossValidationSplitterFactory(client, config, extractedFields)); + config.getAnalysis().supportsMissingValues(), createTrainTestSplitterFactory(client, config, extractedFields)); } - private static CrossValidationSplitterFactory createCrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, - ExtractedFields extractedFields) { - return new CrossValidationSplitterFactory(client, config, + private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, + ExtractedFields extractedFields) { + return new TrainTestSplitterFactory(client, config, extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList())); } @@ -118,7 +118,7 @@ public class DataFrameDataExtractorFactory { DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(), Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields, config.getAnalysis().getRequiredFields(), config.getHeaders(), config.getAnalysis().supportsMissingValues(), - createCrossValidationSplitterFactory(client, config, extractedFields)); + createTrainTestSplitterFactory(client, config, extractedFields)); listener.onResponse(extractorFactory); }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/AbstractReservoirTrainTestSplitter.java similarity index 89% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/AbstractReservoirTrainTestSplitter.java index f2ac127e5b7..8bb4b191a32 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/AbstractReservoirCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/AbstractReservoirTrainTestSplitter.java @@ -4,7 +4,7 @@ * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; @@ -17,14 +17,14 @@ import java.util.Random; * is based on the reservoir idea. It randomly picks training docs while * respecting the exact training percent. */ -abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter { +abstract class AbstractReservoirTrainTestSplitter implements TrainTestSplitter { protected final int dependentVariableIndex; private final double samplingRatio; private final Random random; - AbstractReservoirCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, - long randomizeSeed) { + AbstractReservoirTrainTestSplitter(List fieldNames, String dependentVariable, double trainingPercent, + long randomizeSeed) { assert trainingPercent >= 1.0 && trainingPercent <= 100.0; this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); this.samplingRatio = trainingPercent / 100.0; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitter.java similarity index 58% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitter.java index 3159f6d89c0..b6f1086ff61 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitter.java @@ -4,16 +4,16 @@ * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import java.util.List; -public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter { +public class SingleClassReservoirTrainTestSplitter extends AbstractReservoirTrainTestSplitter { private final SampleInfo sampleInfo; - SingleClassReservoirCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, - long randomizeSeed, long classCount) { + SingleClassReservoirTrainTestSplitter(List fieldNames, String dependentVariable, double trainingPercent, + long randomizeSeed, long classCount) { super(fieldNames, dependentVariable, trainingPercent, randomizeSeed); sampleInfo = new SampleInfo(classCount); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitter.java similarity index 74% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitter.java index 503d6308132..ade1fbe2196 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitter.java @@ -4,7 +4,7 @@ * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import java.util.HashMap; import java.util.List; @@ -14,12 +14,12 @@ import java.util.Map; * Given a dependent variable, randomly splits the dataset trying * to preserve the proportion of each class in the training sample. */ -public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter { +public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter { private final Map classSamples; - public StratifiedCrossValidationSplitter(List fieldNames, String dependentVariable, Map classCounts, - double trainingPercent, long randomizeSeed) { + public StratifiedTrainTestSplitter(List fieldNames, String dependentVariable, Map classCounts, + double trainingPercent, long randomizeSeed) { super(fieldNames, dependentVariable, trainingPercent, randomizeSeed); this.classSamples = new HashMap<>(); classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue()))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitter.java similarity index 76% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitter.java index 5bd6a53e984..7d3b97816e8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitter.java @@ -3,12 +3,12 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; /** * Processes rows in order to split the dataset in training and test subsets */ -public interface CrossValidationSplitter { +public interface TrainTestSplitter { boolean isTraining(String[] row); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitterFactory.java similarity index 83% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitterFactory.java index 3408a642804..4e8f9014d7c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/TrainTestSplitterFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -26,21 +26,21 @@ import java.util.List; import java.util.Map; import java.util.Objects; -public class CrossValidationSplitterFactory { +public class TrainTestSplitterFactory { - private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class); + private static final Logger LOGGER = LogManager.getLogger(TrainTestSplitterFactory.class); private final Client client; private final DataFrameAnalyticsConfig config; private final List fieldNames; - public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List fieldNames) { + public TrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, List fieldNames) { this.client = Objects.requireNonNull(client); this.config = Objects.requireNonNull(config); this.fieldNames = Objects.requireNonNull(fieldNames); } - public CrossValidationSplitter create() { + public TrainTestSplitter create() { if (config.getAnalysis() instanceof Regression) { return createSingleClassSplitter((Regression) config.getAnalysis()); } @@ -50,7 +50,7 @@ public class CrossValidationSplitterFactory { return row -> true; } - private CrossValidationSplitter createSingleClassSplitter(Regression regression) { + private TrainTestSplitter createSingleClassSplitter(Regression regression) { SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex()) .setSize(0) .setAllowPartialSearchResults(false) @@ -60,7 +60,7 @@ public class CrossValidationSplitterFactory { try { SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, searchRequestBuilder::get); - return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(), + return new SingleClassReservoirTrainTestSplitter(fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value); } catch (Exception e) { ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId()); @@ -69,7 +69,7 @@ public class CrossValidationSplitterFactory { } } - private CrossValidationSplitter createStratifiedSplitter(Classification classification) { + private TrainTestSplitter createStratifiedSplitter(Classification classification) { String aggName = "dependent_variable_terms"; SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex()) .setSize(0) @@ -88,7 +88,7 @@ public class CrossValidationSplitterFactory { classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount()); } - return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts, + return new StratifiedTrainTestSplitter(fieldNames, classification.getDependentVariable(), classCounts, classification.getTrainingPercent(), classification.getRandomizeSeed()); } catch (Exception e) { ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index f83904a6324..7280688e713 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; -import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; +import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; @@ -67,7 +67,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { private QueryBuilder query; private int scrollSize; private Map headers; - private CrossValidationSplitterFactory crossValidationSplitterFactory; + private TrainTestSplitterFactory trainTestSplitterFactory; private ArgumentCaptor capturedClearScrollRequests; private ActionFuture clearScrollFuture; @@ -87,8 +87,8 @@ public class DataFrameDataExtractorTests extends ESTestCase { scrollSize = 1000; headers = Collections.emptyMap(); - crossValidationSplitterFactory = mock(CrossValidationSplitterFactory.class); - when(crossValidationSplitterFactory.create()).thenReturn(row -> true); + trainTestSplitterFactory = mock(TrainTestSplitterFactory.class); + when(trainTestSplitterFactory.create()).thenReturn(row -> true); clearScrollFuture = mock(ActionFuture.class); capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class); @@ -467,7 +467,7 @@ public class DataFrameDataExtractorTests extends ESTestCase { private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize, - headers, includeSource, supportsRowsWithMissingValues, crossValidationSplitterFactory); + headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory); return new TestExtractor(client, context); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitterTests.java similarity index 87% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitterTests.java index 84df5d32563..32397d57c68 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/SingleClassReservoirCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/SingleClassReservoirTrainTestSplitterTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; @@ -17,7 +17,7 @@ import java.util.stream.IntStream; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.is; -public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase { +public class SingleClassReservoirTrainTestSplitterTests extends ESTestCase { private List fields; private int dependentVariableIndex; @@ -37,7 +37,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase } public void testIsTraining_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter splitter = createSplitter(50.0, 0); + TrainTestSplitter splitter = createSplitter(50.0, 0); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -53,7 +53,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase } public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter splitter = createSplitter(100.0, 100L); + TrainTestSplitter splitter = createSplitter(100.0, 100L); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -75,7 +75,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase int runCount = 20; int[] trainingRowsPerRun = new int[runCount]; for (int testIndex = 0; testIndex < runCount; testIndex++) { - CrossValidationSplitter splitter = createSplitter(trainingPercent, rowCount); + TrainTestSplitter splitter = createSplitter(trainingPercent, rowCount); int trainingRows = 0; for (int i = 0; i < rowCount; i++) { String[] row = new String[fields.size()]; @@ -99,7 +99,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase } public void testIsTraining_ShouldHaveAtLeastOneTrainingRow() { - CrossValidationSplitter splitter = createSplitter(1.0, 1); + TrainTestSplitter splitter = createSplitter(1.0, 1); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -121,7 +121,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase } } - private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) { - return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount); + private TrainTestSplitter createSplitter(double trainingPercent, long classCount) { + return new SingleClassReservoirTrainTestSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitterTests.java similarity index 90% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitterTests.java index ad26fd7190c..a15b2338bf9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/traintestsplit/StratifiedTrainTestSplitterTests.java @@ -4,7 +4,7 @@ * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; +package org.elasticsearch.xpack.ml.dataframe.traintestsplit; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.test.ESTestCase; @@ -24,7 +24,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThanOrEqualTo; -public class StratifiedCrossValidationSplitterTests extends ESTestCase { +public class StratifiedTrainTestSplitterTests extends ESTestCase { private static final int ROWS_COUNT = 500; @@ -73,13 +73,13 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } public void testConstructor_GivenMissingDependentVariable() { - ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter( + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedTrainTestSplitter( Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0)); assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []")); } public void testIsTraining_GivenUnknownClass() { - CrossValidationSplitter splitter = createSplitter(100.0); + TrainTestSplitter splitter = createSplitter(100.0); String[] row = new String[fields.size()]; for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { row[fieldIndex] = randomAlphaOfLength(5); @@ -93,7 +93,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } public void testIsTraining_GivenRowsWithoutDependentVariableValue() { - CrossValidationSplitter splitter = createSplitter(50.0); + TrainTestSplitter splitter = createSplitter(50.0); for (int i = 0; i < classValuesPerRow.length; i++) { String[] row = new String[fields.size()]; @@ -109,7 +109,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CrossValidationSplitter splitter = createSplitter(100.0); + TrainTestSplitter splitter = createSplitter(100.0); for (int i = 0; i < classValuesPerRow.length; i++) { String[] row = new String[fields.size()]; @@ -128,7 +128,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { // We don't go too low here to avoid flakiness double trainingPercent = randomDoubleBetween(50.0, 100.0, true); - CrossValidationSplitter splitter = createSplitter(trainingPercent); + TrainTestSplitter splitter = createSplitter(trainingPercent); Map totalRowsPerClass = new HashMap<>(); Map trainingRowsPerClass = new HashMap<>(); @@ -189,7 +189,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { for (int run = 0; run < runCount; run++) { randomizeSeed = randomLong(); - CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent); + TrainTestSplitter trainTestSplitter = createSplitter(trainingPercent); for (int i = 0; i < classValuesPerRow.length; i++) { String[] row = new String[fields.size()]; @@ -199,7 +199,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } String[] processedRow = Arrays.copyOf(row, row.length); - boolean isTraining = crossValidationSplitter.isTraining(processedRow); + boolean isTraining = trainTestSplitter.isTraining(processedRow); assertThat(Arrays.equals(processedRow, row), is(true)); if (isTraining) { @@ -223,7 +223,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { classCounts = new HashMap<>(); classCounts.put("class_a", 1L); classCounts.put("class_b", 1L); - CrossValidationSplitter splitter = createSplitter(80.0); + TrainTestSplitter splitter = createSplitter(80.0); { String[] row = new String[]{"class_a", "42.0"}; @@ -245,7 +245,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase { } } - private CrossValidationSplitter createSplitter(double trainingPercent) { - return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed); + private TrainTestSplitter createSplitter(double trainingPercent) { + return new StratifiedTrainTestSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed); } }