Renames and moves the cross validation splitter package. First, the package and classes are renamed from "cross validation splitter" to "train test splitter", because cross validation as a term is overloaded and encompasses more concepts than what we are trying to do here. Second, the package used to live under `process`, but it does not make sense for it to be there; it can be a top-level package under `dataframe`. Backport of #59529
This commit is contained in:
parent
37406487b9
commit
ee4610c0ca
|
@ -29,7 +29,7 @@ import org.elasticsearch.search.sort.SortOrder;
|
|||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.ml.dataframe.DestinationIndex;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitter;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
||||
|
@ -66,14 +66,14 @@ public class DataFrameDataExtractor {
|
|||
private boolean isCancelled;
|
||||
private boolean hasNext;
|
||||
private boolean searchHasShardFailure;
|
||||
private final CachedSupplier<CrossValidationSplitter> crossValidationSplitter;
|
||||
private final CachedSupplier<TrainTestSplitter> trainTestSplitter;
|
||||
|
||||
DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.context = Objects.requireNonNull(context);
|
||||
hasNext = true;
|
||||
searchHasShardFailure = false;
|
||||
this.crossValidationSplitter = new CachedSupplier<>(context.crossValidationSplitterFactory::create);
|
||||
this.trainTestSplitter = new CachedSupplier<>(context.trainTestSplitterFactory::create);
|
||||
}
|
||||
|
||||
public Map<String, String> getHeaders() {
|
||||
|
@ -207,7 +207,7 @@ public class DataFrameDataExtractor {
|
|||
}
|
||||
}
|
||||
}
|
||||
boolean isTraining = extractedValues == null ? false : crossValidationSplitter.get().isTraining(extractedValues);
|
||||
boolean isTraining = extractedValues == null ? false : trainTestSplitter.get().isTraining(extractedValues);
|
||||
return new Row(extractedValues, hit, isTraining);
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
package org.elasticsearch.xpack.ml.dataframe.extractor;
|
||||
|
||||
import org.elasticsearch.index.query.QueryBuilder;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -23,11 +23,11 @@ public class DataFrameDataExtractorContext {
|
|||
final Map<String, String> headers;
|
||||
final boolean includeSource;
|
||||
final boolean supportsRowsWithMissingValues;
|
||||
final CrossValidationSplitterFactory crossValidationSplitterFactory;
|
||||
final TrainTestSplitterFactory trainTestSplitterFactory;
|
||||
|
||||
DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, QueryBuilder query, int scrollSize,
|
||||
Map<String, String> headers, boolean includeSource, boolean supportsRowsWithMissingValues,
|
||||
CrossValidationSplitterFactory crossValidationSplitterFactory) {
|
||||
TrainTestSplitterFactory trainTestSplitterFactory) {
|
||||
this.jobId = Objects.requireNonNull(jobId);
|
||||
this.extractedFields = Objects.requireNonNull(extractedFields);
|
||||
this.indices = indices.toArray(new String[indices.size()]);
|
||||
|
@ -36,6 +36,6 @@ public class DataFrameDataExtractorContext {
|
|||
this.headers = headers;
|
||||
this.includeSource = includeSource;
|
||||
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
|
||||
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
|
||||
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ import org.elasticsearch.index.query.QueryBuilder;
|
|||
import org.elasticsearch.index.query.QueryBuilders;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
||||
|
@ -33,12 +33,12 @@ public class DataFrameDataExtractorFactory {
|
|||
private final List<RequiredField> requiredFields;
|
||||
private final Map<String, String> headers;
|
||||
private final boolean supportsRowsWithMissingValues;
|
||||
private final CrossValidationSplitterFactory crossValidationSplitterFactory;
|
||||
private final TrainTestSplitterFactory trainTestSplitterFactory;
|
||||
|
||||
private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
|
||||
ExtractedFields extractedFields, List<RequiredField> requiredFields, Map<String, String> headers,
|
||||
boolean supportsRowsWithMissingValues,
|
||||
CrossValidationSplitterFactory crossValidationSplitterFactory) {
|
||||
TrainTestSplitterFactory trainTestSplitterFactory) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.analyticsId = Objects.requireNonNull(analyticsId);
|
||||
this.indices = Objects.requireNonNull(indices);
|
||||
|
@ -47,7 +47,7 @@ public class DataFrameDataExtractorFactory {
|
|||
this.requiredFields = Objects.requireNonNull(requiredFields);
|
||||
this.headers = headers;
|
||||
this.supportsRowsWithMissingValues = supportsRowsWithMissingValues;
|
||||
this.crossValidationSplitterFactory = Objects.requireNonNull(crossValidationSplitterFactory);
|
||||
this.trainTestSplitterFactory = Objects.requireNonNull(trainTestSplitterFactory);
|
||||
}
|
||||
|
||||
public DataFrameDataExtractor newExtractor(boolean includeSource) {
|
||||
|
@ -60,7 +60,7 @@ public class DataFrameDataExtractorFactory {
|
|||
headers,
|
||||
includeSource,
|
||||
supportsRowsWithMissingValues,
|
||||
crossValidationSplitterFactory
|
||||
trainTestSplitterFactory
|
||||
);
|
||||
return new DataFrameDataExtractor(client, context);
|
||||
}
|
||||
|
@ -89,12 +89,12 @@ public class DataFrameDataExtractorFactory {
|
|||
ExtractedFields extractedFields) {
|
||||
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()),
|
||||
config.getSource().getParsedQuery(), extractedFields, config.getAnalysis().getRequiredFields(), config.getHeaders(),
|
||||
config.getAnalysis().supportsMissingValues(), createCrossValidationSplitterFactory(client, config, extractedFields));
|
||||
config.getAnalysis().supportsMissingValues(), createTrainTestSplitterFactory(client, config, extractedFields));
|
||||
}
|
||||
|
||||
private static CrossValidationSplitterFactory createCrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config,
|
||||
ExtractedFields extractedFields) {
|
||||
return new CrossValidationSplitterFactory(client, config,
|
||||
private static TrainTestSplitterFactory createTrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config,
|
||||
ExtractedFields extractedFields) {
|
||||
return new TrainTestSplitterFactory(client, config,
|
||||
extractedFields.getAllFields().stream().map(ExtractedField::getName).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ public class DataFrameDataExtractorFactory {
|
|||
DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
|
||||
Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
|
||||
config.getAnalysis().getRequiredFields(), config.getHeaders(), config.getAnalysis().supportsMissingValues(),
|
||||
createCrossValidationSplitterFactory(client, config, extractedFields));
|
||||
createTrainTestSplitterFactory(client, config, extractedFields));
|
||||
listener.onResponse(extractorFactory);
|
||||
},
|
||||
listener::onFailure
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
@ -17,14 +17,14 @@ import java.util.Random;
|
|||
* is based on the reservoir idea. It randomly picks training docs while
|
||||
* respecting the exact training percent.
|
||||
*/
|
||||
abstract class AbstractReservoirCrossValidationSplitter implements CrossValidationSplitter {
|
||||
abstract class AbstractReservoirTrainTestSplitter implements TrainTestSplitter {
|
||||
|
||||
protected final int dependentVariableIndex;
|
||||
private final double samplingRatio;
|
||||
private final Random random;
|
||||
|
||||
AbstractReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed) {
|
||||
AbstractReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed) {
|
||||
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
|
||||
this.samplingRatio = trainingPercent / 100.0;
|
|
@ -4,16 +4,16 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class SingleClassReservoirCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
|
||||
public class SingleClassReservoirTrainTestSplitter extends AbstractReservoirTrainTestSplitter {
|
||||
|
||||
private final SampleInfo sampleInfo;
|
||||
|
||||
SingleClassReservoirCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed, long classCount) {
|
||||
SingleClassReservoirTrainTestSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent,
|
||||
long randomizeSeed, long classCount) {
|
||||
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
|
||||
sampleInfo = new SampleInfo(classCount);
|
||||
}
|
|
@ -4,7 +4,7 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
@ -14,12 +14,12 @@ import java.util.Map;
|
|||
* Given a dependent variable, randomly splits the dataset trying
|
||||
* to preserve the proportion of each class in the training sample.
|
||||
*/
|
||||
public class StratifiedCrossValidationSplitter extends AbstractReservoirCrossValidationSplitter {
|
||||
public class StratifiedTrainTestSplitter extends AbstractReservoirTrainTestSplitter {
|
||||
|
||||
private final Map<String, SampleInfo> classSamples;
|
||||
|
||||
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
|
||||
double trainingPercent, long randomizeSeed) {
|
||||
public StratifiedTrainTestSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCounts,
|
||||
double trainingPercent, long randomizeSeed) {
|
||||
super(fieldNames, dependentVariable, trainingPercent, randomizeSeed);
|
||||
this.classSamples = new HashMap<>();
|
||||
classCounts.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new SampleInfo(entry.getValue())));
|
|
@ -3,12 +3,12 @@
|
|||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
/**
|
||||
* Processes rows in order to split the dataset in training and test subsets
|
||||
*/
|
||||
public interface CrossValidationSplitter {
|
||||
public interface TrainTestSplitter {
|
||||
|
||||
boolean isTraining(String[] row);
|
||||
}
|
|
@ -3,7 +3,7 @@
|
|||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
@ -26,21 +26,21 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class CrossValidationSplitterFactory {
|
||||
public class TrainTestSplitterFactory {
|
||||
|
||||
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
|
||||
private static final Logger LOGGER = LogManager.getLogger(TrainTestSplitterFactory.class);
|
||||
|
||||
private final Client client;
|
||||
private final DataFrameAnalyticsConfig config;
|
||||
private final List<String> fieldNames;
|
||||
|
||||
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
|
||||
public TrainTestSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.config = Objects.requireNonNull(config);
|
||||
this.fieldNames = Objects.requireNonNull(fieldNames);
|
||||
}
|
||||
|
||||
public CrossValidationSplitter create() {
|
||||
public TrainTestSplitter create() {
|
||||
if (config.getAnalysis() instanceof Regression) {
|
||||
return createSingleClassSplitter((Regression) config.getAnalysis());
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ public class CrossValidationSplitterFactory {
|
|||
return row -> true;
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createSingleClassSplitter(Regression regression) {
|
||||
private TrainTestSplitter createSingleClassSplitter(Regression regression) {
|
||||
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
|
||||
.setSize(0)
|
||||
.setAllowPartialSearchResults(false)
|
||||
|
@ -60,7 +60,7 @@ public class CrossValidationSplitterFactory {
|
|||
try {
|
||||
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
|
||||
searchRequestBuilder::get);
|
||||
return new SingleClassReservoirCrossValidationSplitter(fieldNames, regression.getDependentVariable(),
|
||||
return new SingleClassReservoirTrainTestSplitter(fieldNames, regression.getDependentVariable(),
|
||||
regression.getTrainingPercent(), regression.getRandomizeSeed(), searchResponse.getHits().getTotalHits().value);
|
||||
} catch (Exception e) {
|
||||
ParameterizedMessage msg = new ParameterizedMessage("[{}] Error searching total number of training docs", config.getId());
|
||||
|
@ -69,7 +69,7 @@ public class CrossValidationSplitterFactory {
|
|||
}
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
|
||||
private TrainTestSplitter createStratifiedSplitter(Classification classification) {
|
||||
String aggName = "dependent_variable_terms";
|
||||
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
|
||||
.setSize(0)
|
||||
|
@ -88,7 +88,7 @@ public class CrossValidationSplitterFactory {
|
|||
classCounts.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
|
||||
}
|
||||
|
||||
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCounts,
|
||||
return new StratifiedTrainTestSplitter(fieldNames, classification.getDependentVariable(), classCounts,
|
||||
classification.getTrainingPercent(), classification.getRandomizeSeed());
|
||||
} catch (Exception e) {
|
||||
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
|
|
@ -27,7 +27,7 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.traintestsplit.TrainTestSplitterFactory;
|
||||
import org.elasticsearch.xpack.ml.extractor.DocValueField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
|
||||
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
|
||||
|
@ -67,7 +67,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
private QueryBuilder query;
|
||||
private int scrollSize;
|
||||
private Map<String, String> headers;
|
||||
private CrossValidationSplitterFactory crossValidationSplitterFactory;
|
||||
private TrainTestSplitterFactory trainTestSplitterFactory;
|
||||
private ArgumentCaptor<ClearScrollRequest> capturedClearScrollRequests;
|
||||
private ActionFuture<ClearScrollResponse> clearScrollFuture;
|
||||
|
||||
|
@ -87,8 +87,8 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
scrollSize = 1000;
|
||||
headers = Collections.emptyMap();
|
||||
|
||||
crossValidationSplitterFactory = mock(CrossValidationSplitterFactory.class);
|
||||
when(crossValidationSplitterFactory.create()).thenReturn(row -> true);
|
||||
trainTestSplitterFactory = mock(TrainTestSplitterFactory.class);
|
||||
when(trainTestSplitterFactory.create()).thenReturn(row -> true);
|
||||
|
||||
clearScrollFuture = mock(ActionFuture.class);
|
||||
capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class);
|
||||
|
@ -467,7 +467,7 @@ public class DataFrameDataExtractorTests extends ESTestCase {
|
|||
|
||||
private TestExtractor createExtractor(boolean includeSource, boolean supportsRowsWithMissingValues) {
|
||||
DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(JOB_ID, extractedFields, indices, query, scrollSize,
|
||||
headers, includeSource, supportsRowsWithMissingValues, crossValidationSplitterFactory);
|
||||
headers, includeSource, supportsRowsWithMissingValues, trainTestSplitterFactory);
|
||||
return new TestExtractor(client, context);
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
|
@ -17,7 +17,7 @@ import java.util.stream.IntStream;
|
|||
import static org.hamcrest.Matchers.closeTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase {
|
||||
public class SingleClassReservoirTrainTestSplitterTests extends ESTestCase {
|
||||
|
||||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
|
@ -37,7 +37,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
|
|||
}
|
||||
|
||||
public void testIsTraining_GivenRowsWithoutDependentVariableValue() {
|
||||
CrossValidationSplitter splitter = createSplitter(50.0, 0);
|
||||
TrainTestSplitter splitter = createSplitter(50.0, 0);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -53,7 +53,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
|
|||
}
|
||||
|
||||
public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CrossValidationSplitter splitter = createSplitter(100.0, 100L);
|
||||
TrainTestSplitter splitter = createSplitter(100.0, 100L);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -75,7 +75,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
|
|||
int runCount = 20;
|
||||
int[] trainingRowsPerRun = new int[runCount];
|
||||
for (int testIndex = 0; testIndex < runCount; testIndex++) {
|
||||
CrossValidationSplitter splitter = createSplitter(trainingPercent, rowCount);
|
||||
TrainTestSplitter splitter = createSplitter(trainingPercent, rowCount);
|
||||
int trainingRows = 0;
|
||||
for (int i = 0; i < rowCount; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -99,7 +99,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
|
|||
}
|
||||
|
||||
public void testIsTraining_ShouldHaveAtLeastOneTrainingRow() {
|
||||
CrossValidationSplitter splitter = createSplitter(1.0, 1);
|
||||
TrainTestSplitter splitter = createSplitter(1.0, 1);
|
||||
|
||||
// We have some non-training rows and then a training row to check
|
||||
// we maintain the first training row and not just the first row
|
||||
|
@ -121,7 +121,7 @@ public class SingleClassReservoirCrossValidationSplitterTests extends ESTestCase
|
|||
}
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent, long classCount) {
|
||||
return new SingleClassReservoirCrossValidationSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
|
||||
private TrainTestSplitter createSplitter(double trainingPercent, long classCount) {
|
||||
return new SingleClassReservoirTrainTestSplitter(fields, dependentVariable, trainingPercent, randomizeSeed, classCount);
|
||||
}
|
||||
}
|
|
@ -4,7 +4,7 @@
|
|||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
|
||||
package org.elasticsearch.xpack.ml.dataframe.traintestsplit;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
|
@ -24,7 +24,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
|
|||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThanOrEqualTo;
|
||||
|
||||
public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
||||
public class StratifiedTrainTestSplitterTests extends ESTestCase {
|
||||
|
||||
private static final int ROWS_COUNT = 500;
|
||||
|
||||
|
@ -73,13 +73,13 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testConstructor_GivenMissingDependentVariable() {
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedCrossValidationSplitter(
|
||||
ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> new StratifiedTrainTestSplitter(
|
||||
Collections.emptyList(), "foo", Collections.emptyMap(), 100.0, 0));
|
||||
assertThat(e.getMessage(), equalTo("Could not find dependent variable [foo] in fields []"));
|
||||
}
|
||||
|
||||
public void testIsTraining_GivenUnknownClass() {
|
||||
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||
TrainTestSplitter splitter = createSplitter(100.0);
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
row[fieldIndex] = randomAlphaOfLength(5);
|
||||
|
@ -93,7 +93,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testIsTraining_GivenRowsWithoutDependentVariableValue() {
|
||||
CrossValidationSplitter splitter = createSplitter(50.0);
|
||||
TrainTestSplitter splitter = createSplitter(50.0);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -109,7 +109,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
public void testIsTraining_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CrossValidationSplitter splitter = createSplitter(100.0);
|
||||
TrainTestSplitter splitter = createSplitter(100.0);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -128,7 +128,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
// We don't go too low here to avoid flakiness
|
||||
double trainingPercent = randomDoubleBetween(50.0, 100.0, true);
|
||||
|
||||
CrossValidationSplitter splitter = createSplitter(trainingPercent);
|
||||
TrainTestSplitter splitter = createSplitter(trainingPercent);
|
||||
|
||||
Map<String, Integer> totalRowsPerClass = new HashMap<>();
|
||||
Map<String, Integer> trainingRowsPerClass = new HashMap<>();
|
||||
|
@ -189,7 +189,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
for (int run = 0; run < runCount; run++) {
|
||||
|
||||
randomizeSeed = randomLong();
|
||||
CrossValidationSplitter crossValidationSplitter = createSplitter(trainingPercent);
|
||||
TrainTestSplitter trainTestSplitter = createSplitter(trainingPercent);
|
||||
|
||||
for (int i = 0; i < classValuesPerRow.length; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
|
@ -199,7 +199,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
boolean isTraining = crossValidationSplitter.isTraining(processedRow);
|
||||
boolean isTraining = trainTestSplitter.isTraining(processedRow);
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
|
||||
if (isTraining) {
|
||||
|
@ -223,7 +223,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
classCounts = new HashMap<>();
|
||||
classCounts.put("class_a", 1L);
|
||||
classCounts.put("class_b", 1L);
|
||||
CrossValidationSplitter splitter = createSplitter(80.0);
|
||||
TrainTestSplitter splitter = createSplitter(80.0);
|
||||
|
||||
{
|
||||
String[] row = new String[]{"class_a", "42.0"};
|
||||
|
@ -245,7 +245,7 @@ public class StratifiedCrossValidationSplitterTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private CrossValidationSplitter createSplitter(double trainingPercent) {
|
||||
return new StratifiedCrossValidationSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
|
||||
private TrainTestSplitter createSplitter(double trainingPercent) {
|
||||
return new StratifiedTrainTestSplitter(fields, dependentVariable, classCounts, trainingPercent, randomizeSeed);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue