[7.x][ML] Add option to regression to randomize training set (#45969) (#46017)

Adds a parameter `training_percent` to regression. The default
value is `100`. When the parameter is set to a value less than `100`,
we randomly choose, for each row that is eligible for training (i.e.
has a value for the dependent variable), whether it is actually used
for training. This enables splitting the data into a training set and
a remainder, usually called the test, validation or holdout set, which
allows the model to be validated on data that were not used for training.

Technically, the analytics process treats as training data the rows that
have a value for the dependent variable. Thus, when we decide that an
eligible row will not be used for training, we simply clear the row's
dependent variable.
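
For illustration, a minimal usage sketch of the new constructor argument (the dependent variable name `price` is hypothetical, not taken from this change); the other arguments are the existing hyperparameters:

// Train on roughly 75% of the rows that have a value for "price";
// the remaining eligible rows become the holdout set.
Regression regression = new Regression("price", null, null, null, null, null, null, 75.0);

// Passing null for training_percent, or using the single-argument constructor,
// keeps the default of training on all eligible rows (100.0).
Regression defaultSplit = new Regression("price");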
Dimitris Athanasiou 2019-08-27 17:53:11 +03:00 committed by GitHub
parent 7b6246ec67
commit 873ad3f942
15 changed files with 622 additions and 95 deletions

View File

@ -32,13 +32,15 @@ public class Regression implements DataFrameAnalysis {
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6],
(Double) a[7]));
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
@ -46,6 +48,7 @@ public class Regression implements DataFrameAnalysis {
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
return parser;
}
@ -60,9 +63,11 @@ public class Regression implements DataFrameAnalysis {
private final Integer maximumNumberTrees;
private final Double featureBagFraction;
private final String predictionFieldName;
private final double trainingPercent;
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
@Nullable Double trainingPercent) {
this.dependentVariable = Objects.requireNonNull(dependentVariable);
if (lambda != null && lambda < 0) {
@ -91,10 +96,15 @@ public class Regression implements DataFrameAnalysis {
this.featureBagFraction = featureBagFraction;
this.predictionFieldName = predictionFieldName;
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
}
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
public Regression(String dependentVariable) {
this(dependentVariable, null, null, null, null, null, null);
this(dependentVariable, null, null, null, null, null, null, null);
}
public Regression(StreamInput in) throws IOException {
@ -105,6 +115,15 @@ public class Regression implements DataFrameAnalysis {
maximumNumberTrees = in.readOptionalVInt();
featureBagFraction = in.readOptionalDouble();
predictionFieldName = in.readOptionalString();
trainingPercent = in.readDouble();
}
public String getDependentVariable() {
return dependentVariable;
}
public double getTrainingPercent() {
return trainingPercent;
}
@Override
@ -121,6 +140,7 @@ public class Regression implements DataFrameAnalysis {
out.writeOptionalVInt(maximumNumberTrees);
out.writeOptionalDouble(featureBagFraction);
out.writeOptionalString(predictionFieldName);
out.writeDouble(trainingPercent);
}
@Override
@ -145,6 +165,7 @@ public class Regression implements DataFrameAnalysis {
if (predictionFieldName != null) {
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
}
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
builder.endObject();
return builder;
}
@ -191,7 +212,8 @@ public class Regression implements DataFrameAnalysis {
@Override
public int hashCode() {
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
trainingPercent);
}
@Override
@ -205,6 +227,7 @@ public class Regression implements DataFrameAnalysis {
&& Objects.equals(eta, that.eta)
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
&& Objects.equals(featureBagFraction, that.featureBagFraction)
&& Objects.equals(predictionFieldName, that.predictionFieldName);
&& Objects.equals(predictionFieldName, that.predictionFieldName)
&& trainingPercent == that.trainingPercent;
}
}

View File

@ -467,6 +467,9 @@ public class ElasticsearchMappings {
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
.field(TYPE, KEYWORD)
.endObject()
.startObject(Regression.TRAINING_PERCENT.getPreferredName())
.field(TYPE, DOUBLE)
.endObject()
.endObject()
.endObject()
.endObject()

View File

@ -309,6 +309,7 @@ public final class ReservedFieldNames {
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
Regression.TRAINING_PERCENT.getPreferredName(),
ElasticsearchMappings.CONFIG_TYPE,

View File

@ -33,8 +33,9 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
predictionFieldName);
predictionFieldName, trainingPercent);
}
@Override
@ -44,57 +45,83 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
public void testRegression_GivenNegativeLambda() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result"));
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
}
public void testRegression_GivenNegativeGamma() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result"));
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
}
public void testRegression_GivenEtaIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result"));
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenEtaIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result"));
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
}
public void testRegression_GivenMaximumNumberTreesIsZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result"));
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result"));
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0));
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
}
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result"));
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result"));
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0));
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
}
public void testRegression_GivenTrainingPercentIsNull() {
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testRegression_GivenTrainingPercentIsBoundary() {
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0);
assertThat(regression.getTrainingPercent(), equalTo(1.0));
regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0);
assertThat(regression.getTrainingPercent(), equalTo(100.0));
}
public void testRegression_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testRegression_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
}

View File

@ -71,6 +71,8 @@ integTest.runner {
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
'ml/evaluate_data_frame/Test given missing index',
'ml/evaluate_data_frame/Test given index does not exist',
'ml/evaluate_data_frame/Test given missing evaluation',

View File

@ -143,12 +143,12 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
}
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
@Nullable String resultsField, String dependentVariable) {
@Nullable String resultsField, Regression regression) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(new Regression(dependentVariable));
configBuilder.setAnalysis(regression);
return configBuilder.build();
}
}

View File

@ -0,0 +1,234 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.junit.After;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
@After
public void cleanup() {
cleanUp();
}
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
String jobId = "regression_single_numeric_feature_and_mixed_data_set";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
if (i < 300) {
indexRequest.source("feature", field, "variable", value);
} else {
indexRequest.source("feature", field);
}
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable"));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
// TODO reenable this assertion when the backend is stable
// it seems for this case values can be as far off as 2.0
// double featureValue = (double) destDoc.get("feature");
// double predictionValue = (double) resultsObject.get("variable_prediction");
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
boolean expectedIsTraining = destDoc.containsKey("variable");
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(expectedIsTraining));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
String jobId = "regression_only_training_data_and_training_percent_is_hundred";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
indexRequest.source("feature", field, "variable", value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable"));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
assertThat(resultsObject.get("is_training"), is(true));
}
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
}
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
String jobId = "regression_only_training_data_and_training_percent_is_fifty";
String sourceIndex = jobId + "_source_index";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
indexRequest.source("feature", field, "variable", value);
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String destIndex = sourceIndex + "_results";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
new Regression("variable", null, null, null, null, null, null, 50.0));
registerAnalytics(config);
putAnalytics(config);
assertState(jobId, DataFrameAnalyticsState.STOPPED);
assertProgress(jobId, 0, 0, 0, 0);
startAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
int trainingRowsCount = 0;
int nonTrainingRowsCount = 0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
assertThat(resultsObject.containsKey("is_training"), is(true));
// Let's just assert there are both training and non-training results
if ((boolean) resultsObject.get("is_training")) {
trainingRowsCount++;
} else {
nonTrainingRowsCount++;
}
}
assertThat(trainingRowsCount, greaterThan(0));
assertThat(nonTrainingRowsCount, greaterThan(0));
assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
}
}

View File

@ -28,8 +28,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
import org.junit.After;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.allOf;
@ -393,77 +391,6 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
String sourceIndex = "test-regression-with-numeric-feature-and-few-docs";
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
for (int i = 0; i < 350; i++) {
Double field = featureValues.get(i % 3);
Double value = dependentVariableValues.get(i % 3);
IndexRequest indexRequest = new IndexRequest(sourceIndex);
if (i < 300) {
indexRequest.source("feature", field, "variable", value);
} else {
indexRequest.source("feature", field);
}
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}
String id = "test_regression_with_numeric_feature_and_few_docs";
DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex},
sourceIndex + "-results", null, "variable");
registerAnalytics(config);
putAnalytics(config);
assertState(id, DataFrameAnalyticsState.STOPPED);
assertProgress(id, 0, 0, 0, 0);
startAnalytics(id);
waitUntilAnalyticsIsStopped(id);
int resultsWithPrediction = 0;
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
assertThat(sourceData.getHits().getTotalHits().value, equalTo(350L));
for (SearchHit hit : sourceData.getHits()) {
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
assertThat(destDocGetResponse.isExists(), is(true));
Map<String, Object> sourceDoc = hit.getSourceAsMap();
Map<String, Object> destDoc = destDocGetResponse.getSource();
for (String field : sourceDoc.keySet()) {
assertThat(destDoc.containsKey(field), is(true));
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
}
assertThat(destDoc.containsKey("ml"), is(true));
@SuppressWarnings("unchecked")
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
if (resultsObject.containsKey("variable_prediction")) {
resultsWithPrediction++;
double featureValue = (double) destDoc.get("feature");
double predictionValue = (double) resultsObject.get("variable_prediction");
// TODO reenable this assertion when the backend is stable
// it seems for this case values can be as far off as 2.0
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
}
}
assertThat(resultsWithPrediction, greaterThan(0));
assertProgress(id, 100, 100, 100, 100);
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
}
public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() {
String sourceIndex = "test-model-memory-limit";

View File

@ -16,11 +16,14 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import java.io.IOException;
@ -91,7 +94,7 @@ public class AnalyticsProcessManager {
try {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
writeHeaderRecord(dataExtractor, process);
writeDataRows(dataExtractor, process, task.getProgressTracker());
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getProgressTracker());
process.writeEndOfDataMessage();
process.flushStream();
@ -123,7 +126,10 @@ public class AnalyticsProcessManager {
}
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
DataFrameAnalysis analysis, DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
// The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
// The value of the control field should be an empty string for data frame rows
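// e.g. with fields [feature, variable], a record is {"1.0", "10.0", "<doc hash>", ""} (hypothetical values for illustration)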
@ -140,6 +146,7 @@ public class AnalyticsProcessManager {
String[] rowValues = row.getValues();
System.arraycopy(rowValues, 0, record, 0, rowValues.length);
record[record.length - 2] = String.valueOf(row.getChecksum());
customProcessor.process(record);
process.writeRecord(record);
}
}

View File

@ -53,7 +53,6 @@ public class AnalyticsResultProcessor {
public void process(AnalyticsProcess<AnalyticsResult> process) {
long totalRows = process.getConfig().rows();
LOGGER.info("Total rows = {}", totalRows);
long processedRows = 0;
// TODO When java 9 features can be used, we will not need the local variable here

View File

@ -0,0 +1,14 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
/**
* A processor to manipulate rows before writing them to the process
*/
public interface CustomProcessor {
void process(String[] row);
}

View File

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import java.util.List;
import java.util.Objects;
public class CustomProcessorFactory {
private final List<String> fieldNames;
public CustomProcessorFactory(List<String> fieldNames) {
this.fieldNames = Objects.requireNonNull(fieldNames);
}
public CustomProcessor create(DataFrameAnalysis analysis) {
if (analysis instanceof Regression) {
return new RegressionCustomProcessor(fieldNames, (Regression) analysis);
}
return row -> {};
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.List;
import java.util.Random;
/**
* A processor that randomly clears the dependent variable value
* in order to split the dataset into training and validation data.
* This relies on the fact that when the dependent variable field
* is empty, the row is not used for training but only to make predictions.
*/
class RegressionCustomProcessor implements CustomProcessor {
private static final String EMPTY = "";
private final int dependentVariableIndex;
private final double trainingPercent;
private final Random random = Randomness.get();
private boolean isFirstRow = true;
RegressionCustomProcessor(List<String> fieldNames, Regression regression) {
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, regression.getDependentVariable());
this.trainingPercent = regression.getTrainingPercent();
}
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
for (int i = 0; i < fieldNames.size(); i++) {
if (fieldNames.get(i).equals(dependentVariable)) {
return i;
}
}
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
}
@Override
public void process(String[] row) {
if (canBeUsedForTraining(row)) {
if (isFirstRow) {
// Let's make sure we have at least one training row
isFirstRow = false;
} else if (isRandomlyExcludedFromTraining()) {
row[dependentVariableIndex] = EMPTY;
}
}
}
private boolean canBeUsedForTraining(String[] row) {
return row[dependentVariableIndex].length() > 0;
}
private boolean isRandomlyExcludedFromTraining() {
return random.nextDouble() * 100 > trainingPercent;
}
}
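
For illustration, a minimal sketch of how this processor is wired and behaves, assuming the `Regression`, `CustomProcessor` and `CustomProcessorFactory` classes from this change are importable (the field names and values below are hypothetical):

List<String> fieldNames = Arrays.asList("feature", "variable");
Regression regression = new Regression("variable", null, null, null, null, null, null, 50.0);
CustomProcessor processor = new CustomProcessorFactory(fieldNames).create(regression);

String[] first = new String[] {"1.0", "10.0"};
processor.process(first);   // the first row with a dependent variable value is always kept for training

String[] next = new String[] {"2.0", "20.0"};
processor.process(next);    // with training_percent 50, next[1] is cleared to "" about half the time

String[] noLabel = new String[] {"3.0", ""};
processor.process(noLabel); // rows without a dependent variable value are left untouched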

View File

@ -0,0 +1,149 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.Matchers.both;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;
public class RegressionCustomProcessorTests extends ESTestCase {
private List<String> fields;
private int dependentVariableIndex;
private String dependentVariable;
@Before
public void setUpTests() {
int fieldCount = randomIntBetween(1, 5);
fields = new ArrayList<>(fieldCount);
for (int i = 0; i < fieldCount; i++) {
fields.add(randomAlphaOfLength(10));
}
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
dependentVariable = fields.get(dependentVariableIndex);
}
public void testProcess_GivenRowsWithoutDependentVariableValue() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 50.0));
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
// As all these rows have no dependent variable value, they're not for training and should be unaffected
assertThat(Arrays.equals(processedRow, row), is(true));
}
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 100.0));
for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
row[fieldIndex] = value;
}
String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
// We should pick them all as training percent is 100
assertThat(Arrays.equals(processedRow, row), is(true));
}
}
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
double trainingFraction = trainingPercent / 100;
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, trainingPercent));
int runCount = 20;
int rowsCount = 1000;
int[] trainingRowsPerRun = new int[runCount];
for (int testIndex = 0; testIndex < runCount; testIndex++) {
int trainingRows = 0;
for (int i = 0; i < rowsCount; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
row[fieldIndex] = randomAlphaOfLength(10);
}
String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
if (fieldIndex != dependentVariableIndex) {
assertThat(processedRow[fieldIndex], equalTo(row[fieldIndex]));
}
}
if (processedRow[dependentVariableIndex].length() > 0) {
assertThat(processedRow[dependentVariableIndex], equalTo(row[dependentVariableIndex]));
trainingRows++;
}
}
trainingRowsPerRun[testIndex] = trainingRows;
}
double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble();
// Now we need to calculate sensible bounds to assert against.
// We'll use 5 standard deviations which should mean the test only fails once in 7M
// And, because we're doing multiple runs, we'll divide the variance by the number of runs to narrow the bounds
double expectedTrainingRows = trainingFraction * rowsCount;
double variance = rowsCount * (Math.pow(1 - trainingFraction, 2) * trainingFraction
+ Math.pow(trainingFraction, 2) * (1 - trainingFraction));
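// (Note: this expression simplifies to the binomial variance rowsCount * trainingFraction * (1 - trainingFraction).)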
double lowerBound = expectedTrainingRows - 5 * Math.sqrt(variance / runCount);
double upperBound = expectedTrainingRows + 5 * Math.sqrt(variance / runCount);
assertThat("Mean training rows [" + meanTrainingRows + "] was not within expected bounds of [" + lowerBound + ", "
+ upperBound + "] given training fraction was [" + trainingFraction + "]",
meanTrainingRows, is(both(greaterThan(lowerBound)).and(lessThan(upperBound))));
}
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 1.0));
// We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row
for (int i = 0; i < 10; i++) {
String[] row = new String[fields.size()];
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
if (i < 9 && fieldIndex == dependentVariableIndex) {
row[fieldIndex] = "";
} else {
row[fieldIndex] = randomAlphaOfLength(10);
}
}
String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
assertThat(Arrays.equals(processedRow, row), is(true));
}
}
private static Regression regression(String dependentVariable, double trainingPercent) {
return new Regression(dependentVariable, null, null, null, null, null, null, trainingPercent);
}
}

View File

@ -1142,6 +1142,52 @@ setup:
}
}
---
"Test put regression given training_percent is less than one":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
ml.put_data_frame_analytics:
id: "regression-training-percent-is-less-than-one"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"training_percent": 0.999
}
}
}
---
"Test put regression given training_percent is greater than hundred":
- do:
catch: /\[training_percent\] must be a double in \[1, 100\]/
ml.put_data_frame_analytics:
id: "regression-training-percent-is-greater-than-hundred"
body: >
{
"source": {
"index": "index-source"
},
"dest": {
"index": "index-dest"
},
"analysis": {
"regression": {
"dependent_variable": "foo",
"training_percent": 100.1
}
}
}
---
"Test put regression given valid":
@ -1163,7 +1209,8 @@ setup:
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
"feature_bag_fraction": 0.3,
"training_percent": 60.3
}
}
}
@ -1177,7 +1224,8 @@ setup:
"gamma": 0.42,
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3
"feature_bag_fraction": 0.3,
"training_percent": 60.3
}
}}
- is_true: create_time
@ -1210,7 +1258,8 @@ setup:
- match: { dest.index: "index-dest" }
- match: { analysis: {
"regression":{
"dependent_variable": "foo"
"dependent_variable": "foo",
"training_percent": 100.0
}
}}
- is_true: create_time