mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-17 10:25:15 +00:00
Adds a parameter `training_percent` to regression. The default value is `100`. When the parameter is set to a value less than `100`, from the rows that can be used for training (ie. those that have a value for the dependent variable) we randomly choose whether to actually use for training. This enables splitting the data into a training set and the rest, usually called testing, validation or holdout set, which allows for validating the model on data that have not been used for training. Technically, the analytics process considers as training the data that have a value for the dependent variable. Thus, when we decide a training row is not going to be used for training, we simply clear the row's dependent variable.
This commit is contained in:
parent
7b6246ec67
commit
873ad3f942
@ -32,13 +32,15 @@ public class Regression implements DataFrameAnalysis {
|
||||
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
|
||||
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
|
||||
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
|
||||
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
|
||||
|
||||
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
|
||||
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
|
||||
|
||||
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
|
||||
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
|
||||
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
|
||||
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6],
|
||||
(Double) a[7]));
|
||||
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
|
||||
@ -46,6 +48,7 @@ public class Regression implements DataFrameAnalysis {
|
||||
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
|
||||
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
|
||||
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
|
||||
return parser;
|
||||
}
|
||||
|
||||
@ -60,9 +63,11 @@ public class Regression implements DataFrameAnalysis {
|
||||
private final Integer maximumNumberTrees;
|
||||
private final Double featureBagFraction;
|
||||
private final String predictionFieldName;
|
||||
private final double trainingPercent;
|
||||
|
||||
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
|
||||
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
|
||||
@Nullable Double trainingPercent) {
|
||||
this.dependentVariable = Objects.requireNonNull(dependentVariable);
|
||||
|
||||
if (lambda != null && lambda < 0) {
|
||||
@ -91,10 +96,15 @@ public class Regression implements DataFrameAnalysis {
|
||||
this.featureBagFraction = featureBagFraction;
|
||||
|
||||
this.predictionFieldName = predictionFieldName;
|
||||
|
||||
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
|
||||
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
|
||||
}
|
||||
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
|
||||
}
|
||||
|
||||
public Regression(String dependentVariable) {
|
||||
this(dependentVariable, null, null, null, null, null, null);
|
||||
this(dependentVariable, null, null, null, null, null, null, null);
|
||||
}
|
||||
|
||||
public Regression(StreamInput in) throws IOException {
|
||||
@ -105,6 +115,15 @@ public class Regression implements DataFrameAnalysis {
|
||||
maximumNumberTrees = in.readOptionalVInt();
|
||||
featureBagFraction = in.readOptionalDouble();
|
||||
predictionFieldName = in.readOptionalString();
|
||||
trainingPercent = in.readDouble();
|
||||
}
|
||||
|
||||
public String getDependentVariable() {
|
||||
return dependentVariable;
|
||||
}
|
||||
|
||||
public double getTrainingPercent() {
|
||||
return trainingPercent;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -121,6 +140,7 @@ public class Regression implements DataFrameAnalysis {
|
||||
out.writeOptionalVInt(maximumNumberTrees);
|
||||
out.writeOptionalDouble(featureBagFraction);
|
||||
out.writeOptionalString(predictionFieldName);
|
||||
out.writeDouble(trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -145,6 +165,7 @@ public class Regression implements DataFrameAnalysis {
|
||||
if (predictionFieldName != null) {
|
||||
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
|
||||
}
|
||||
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
@ -191,7 +212,8 @@ public class Regression implements DataFrameAnalysis {
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
|
||||
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
|
||||
trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -205,6 +227,7 @@ public class Regression implements DataFrameAnalysis {
|
||||
&& Objects.equals(eta, that.eta)
|
||||
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
|
||||
&& Objects.equals(featureBagFraction, that.featureBagFraction)
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName);
|
||||
&& Objects.equals(predictionFieldName, that.predictionFieldName)
|
||||
&& trainingPercent == that.trainingPercent;
|
||||
}
|
||||
}
|
||||
|
@ -467,6 +467,9 @@ public class ElasticsearchMappings {
|
||||
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
|
||||
.field(TYPE, KEYWORD)
|
||||
.endObject()
|
||||
.startObject(Regression.TRAINING_PERCENT.getPreferredName())
|
||||
.field(TYPE, DOUBLE)
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
.endObject()
|
||||
|
@ -309,6 +309,7 @@ public final class ReservedFieldNames {
|
||||
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
|
||||
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
|
||||
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
|
||||
Regression.TRAINING_PERCENT.getPreferredName(),
|
||||
|
||||
ElasticsearchMappings.CONFIG_TYPE,
|
||||
|
||||
|
@ -33,8 +33,9 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
|
||||
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
|
||||
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
|
||||
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(0.0, 100.0, true);
|
||||
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
|
||||
predictionFieldName);
|
||||
predictionFieldName, trainingPercent);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -44,57 +45,83 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
|
||||
|
||||
public void testRegression_GivenNegativeLambda() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result"));
|
||||
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenNegativeGamma() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result"));
|
||||
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenEtaIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenEtaIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenMaximumNumberTreesIsZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result"));
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsNull() {
|
||||
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsBoundary() {
|
||||
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(1.0));
|
||||
regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0);
|
||||
assertThat(regression.getTrainingPercent(), equalTo(100.0));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsLessThanOne() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
|
||||
public void testRegression_GivenTrainingPercentIsGreaterThan100() {
|
||||
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
|
||||
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001));
|
||||
|
||||
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
|
||||
}
|
||||
}
|
||||
|
@ -71,6 +71,8 @@ integTest.runner {
|
||||
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
|
||||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
|
||||
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
|
||||
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
|
||||
'ml/evaluate_data_frame/Test given missing index',
|
||||
'ml/evaluate_data_frame/Test given index does not exist',
|
||||
'ml/evaluate_data_frame/Test given missing evaluation',
|
||||
|
@ -143,12 +143,12 @@ abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTest
|
||||
}
|
||||
|
||||
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
|
||||
@Nullable String resultsField, String dependentVariable) {
|
||||
@Nullable String resultsField, Regression regression) {
|
||||
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
|
||||
configBuilder.setId(id);
|
||||
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
|
||||
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
|
||||
configBuilder.setAnalysis(new Regression(dependentVariable));
|
||||
configBuilder.setAnalysis(regression);
|
||||
return configBuilder.build();
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,234 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.elasticsearch.action.bulk.BulkRequestBuilder;
|
||||
import org.elasticsearch.action.bulk.BulkResponse;
|
||||
import org.elasticsearch.action.get.GetResponse;
|
||||
import org.elasticsearch.action.index.IndexRequest;
|
||||
import org.elasticsearch.action.search.SearchResponse;
|
||||
import org.elasticsearch.action.support.WriteRequest;
|
||||
import org.elasticsearch.search.SearchHit;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
|
||||
|
||||
@After
|
||||
public void cleanup() {
|
||||
cleanUp();
|
||||
}
|
||||
|
||||
public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception {
|
||||
String jobId = "regression_single_numeric_feature_and_mixed_data_set";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
if (i < 300) {
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
} else {
|
||||
indexRequest.source("feature", field);
|
||||
}
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable"));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
|
||||
// TODO reenable this assertion when the backend is stable
|
||||
// it seems for this case values can be as far off as 2.0
|
||||
|
||||
// double featureValue = (double) destDoc.get("feature");
|
||||
// double predictionValue = (double) resultsObject.get("variable_prediction");
|
||||
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
||||
|
||||
boolean expectedIsTraining = destDoc.containsKey("variable");
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(expectedIsTraining));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception {
|
||||
String jobId = "regression_only_training_data_and_training_percent_is_hundred";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable"));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
assertThat(resultsObject.get("is_training"), is(true));
|
||||
}
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
}
|
||||
|
||||
public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception {
|
||||
String jobId = "regression_only_training_data_and_training_percent_is_fifty";
|
||||
String sourceIndex = jobId + "_source_index";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
|
||||
String destIndex = sourceIndex + "_results";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
|
||||
new Regression("variable", null, null, null, null, null, null, 50.0));
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(jobId, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(jobId, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(jobId);
|
||||
waitUntilAnalyticsIsStopped(jobId);
|
||||
|
||||
int trainingRowsCount = 0;
|
||||
int nonTrainingRowsCount = 0;
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
|
||||
assertThat(resultsObject.containsKey("is_training"), is(true));
|
||||
// Let's just assert there's both training and non-training results
|
||||
if ((boolean) resultsObject.get("is_training")) {
|
||||
trainingRowsCount++;
|
||||
} else {
|
||||
nonTrainingRowsCount++;
|
||||
}
|
||||
}
|
||||
assertThat(trainingRowsCount, greaterThan(0));
|
||||
assertThat(nonTrainingRowsCount, greaterThan(0));
|
||||
|
||||
assertProgress(jobId, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
|
||||
}
|
||||
}
|
@ -28,8 +28,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection;
|
||||
import org.junit.After;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.hamcrest.Matchers.allOf;
|
||||
@ -393,77 +391,6 @@ public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTest
|
||||
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
|
||||
}
|
||||
|
||||
public void testRegressionWithNumericFeatureAndFewDocuments() throws Exception {
|
||||
String sourceIndex = "test-regression-with-numeric-feature-and-few-docs";
|
||||
|
||||
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
|
||||
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
|
||||
|
||||
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
|
||||
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
|
||||
|
||||
for (int i = 0; i < 350; i++) {
|
||||
Double field = featureValues.get(i % 3);
|
||||
Double value = dependentVariableValues.get(i % 3);
|
||||
|
||||
IndexRequest indexRequest = new IndexRequest(sourceIndex);
|
||||
if (i < 300) {
|
||||
indexRequest.source("feature", field, "variable", value);
|
||||
} else {
|
||||
indexRequest.source("feature", field);
|
||||
}
|
||||
bulkRequestBuilder.add(indexRequest);
|
||||
}
|
||||
BulkResponse bulkResponse = bulkRequestBuilder.get();
|
||||
if (bulkResponse.hasFailures()) {
|
||||
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
|
||||
}
|
||||
|
||||
String id = "test_regression_with_numeric_feature_and_few_docs";
|
||||
DataFrameAnalyticsConfig config = buildRegressionAnalytics(id, new String[] {sourceIndex},
|
||||
sourceIndex + "-results", null, "variable");
|
||||
registerAnalytics(config);
|
||||
putAnalytics(config);
|
||||
|
||||
assertState(id, DataFrameAnalyticsState.STOPPED);
|
||||
assertProgress(id, 0, 0, 0, 0);
|
||||
|
||||
startAnalytics(id);
|
||||
waitUntilAnalyticsIsStopped(id);
|
||||
|
||||
int resultsWithPrediction = 0;
|
||||
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
|
||||
assertThat(sourceData.getHits().getTotalHits().value, equalTo(350L));
|
||||
for (SearchHit hit : sourceData.getHits()) {
|
||||
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
|
||||
assertThat(destDocGetResponse.isExists(), is(true));
|
||||
Map<String, Object> sourceDoc = hit.getSourceAsMap();
|
||||
Map<String, Object> destDoc = destDocGetResponse.getSource();
|
||||
for (String field : sourceDoc.keySet()) {
|
||||
assertThat(destDoc.containsKey(field), is(true));
|
||||
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
|
||||
}
|
||||
assertThat(destDoc.containsKey("ml"), is(true));
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
|
||||
|
||||
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
|
||||
if (resultsObject.containsKey("variable_prediction")) {
|
||||
resultsWithPrediction++;
|
||||
double featureValue = (double) destDoc.get("feature");
|
||||
double predictionValue = (double) resultsObject.get("variable_prediction");
|
||||
// TODO reenable this assertion when the backend is stable
|
||||
// it seems for this case values can be as far off as 2.0
|
||||
// assertThat(predictionValue, closeTo(10 * featureValue, 2.0));
|
||||
}
|
||||
}
|
||||
assertThat(resultsWithPrediction, greaterThan(0));
|
||||
|
||||
assertProgress(id, 100, 100, 100, 100);
|
||||
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
|
||||
}
|
||||
|
||||
public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() {
|
||||
String sourceIndex = "test-model-memory-limit";
|
||||
|
||||
|
@ -16,11 +16,14 @@ import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.ClientHelper;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -91,7 +94,7 @@ public class AnalyticsProcessManager {
|
||||
try {
|
||||
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
|
||||
writeHeaderRecord(dataExtractor, process);
|
||||
writeDataRows(dataExtractor, process, task.getProgressTracker());
|
||||
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getProgressTracker());
|
||||
process.writeEndOfDataMessage();
|
||||
process.flushStream();
|
||||
|
||||
@ -123,7 +126,10 @@ public class AnalyticsProcessManager {
|
||||
}
|
||||
|
||||
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
|
||||
DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
|
||||
DataFrameAnalysis analysis, DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
|
||||
|
||||
CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
|
||||
|
||||
// The extra fields are for the doc hash and the control field (should be an empty string)
|
||||
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
|
||||
// The value of the control field should be an empty string for data frame rows
|
||||
@ -140,6 +146,7 @@ public class AnalyticsProcessManager {
|
||||
String[] rowValues = row.getValues();
|
||||
System.arraycopy(rowValues, 0, record, 0, rowValues.length);
|
||||
record[record.length - 2] = String.valueOf(row.getChecksum());
|
||||
customProcessor.process(record);
|
||||
process.writeRecord(record);
|
||||
}
|
||||
}
|
||||
|
@ -53,7 +53,6 @@ public class AnalyticsResultProcessor {
|
||||
|
||||
public void process(AnalyticsProcess<AnalyticsResult> process) {
|
||||
long totalRows = process.getConfig().rows();
|
||||
LOGGER.info("Total rows = {}", totalRows);
|
||||
long processedRows = 0;
|
||||
|
||||
// TODO When java 9 features can be used, we will not need the local variable here
|
||||
|
@ -0,0 +1,14 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
/**
|
||||
* A processor to manipulate rows before writing them to the process
|
||||
*/
|
||||
public interface CustomProcessor {
|
||||
|
||||
void process(String[] row);
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
public class CustomProcessorFactory {
|
||||
|
||||
private final List<String> fieldNames;
|
||||
|
||||
public CustomProcessorFactory(List<String> fieldNames) {
|
||||
this.fieldNames = Objects.requireNonNull(fieldNames);
|
||||
}
|
||||
|
||||
public CustomProcessor create(DataFrameAnalysis analysis) {
|
||||
if (analysis instanceof Regression) {
|
||||
return new RegressionCustomProcessor(fieldNames, (Regression) analysis);
|
||||
}
|
||||
return row -> {};
|
||||
}
|
||||
}
|
@ -0,0 +1,64 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.common.Randomness;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* A processor that randomly clears the dependent variable value
|
||||
* in order to split the dataset in training and validation data.
|
||||
* This relies on the fact that when the dependent variable field
|
||||
* is empty, then the row is not used for training but only to make predictions.
|
||||
*/
|
||||
class RegressionCustomProcessor implements CustomProcessor {
|
||||
|
||||
private static final String EMPTY = "";
|
||||
|
||||
private final int dependentVariableIndex;
|
||||
private final double trainingPercent;
|
||||
private final Random random = Randomness.get();
|
||||
private boolean isFirstRow = true;
|
||||
|
||||
RegressionCustomProcessor(List<String> fieldNames, Regression regression) {
|
||||
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, regression.getDependentVariable());
|
||||
this.trainingPercent = regression.getTrainingPercent();
|
||||
|
||||
}
|
||||
|
||||
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
|
||||
for (int i = 0; i < fieldNames.size(); i++) {
|
||||
if (fieldNames.get(i).equals(dependentVariable)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(String[] row) {
|
||||
if (canBeUsedForTraining(row)) {
|
||||
if (isFirstRow) {
|
||||
// Let's make sure we have at least one training row
|
||||
isFirstRow = false;
|
||||
} else if (isRandomlyExcludedFromTraining()) {
|
||||
row[dependentVariableIndex] = EMPTY;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean canBeUsedForTraining(String[] row) {
|
||||
return row[dependentVariableIndex].length() > 0;
|
||||
}
|
||||
|
||||
private boolean isRandomlyExcludedFromTraining() {
|
||||
return random.nextDouble() * 100 > trainingPercent;
|
||||
}
|
||||
}
|
@ -0,0 +1,149 @@
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
|
||||
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.hamcrest.CoreMatchers.equalTo;
|
||||
import static org.hamcrest.Matchers.both;
|
||||
import static org.hamcrest.Matchers.greaterThan;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.lessThan;
|
||||
|
||||
public class RegressionCustomProcessorTests extends ESTestCase {
|
||||
|
||||
private List<String> fields;
|
||||
private int dependentVariableIndex;
|
||||
private String dependentVariable;
|
||||
|
||||
@Before
|
||||
public void setUpTests() {
|
||||
int fieldCount = randomIntBetween(1, 5);
|
||||
fields = new ArrayList<>(fieldCount);
|
||||
for (int i = 0; i < fieldCount; i++) {
|
||||
fields.add(randomAlphaOfLength(10));
|
||||
}
|
||||
dependentVariableIndex = randomIntBetween(0, fieldCount - 1);
|
||||
dependentVariable = fields.get(dependentVariableIndex);
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithoutDependentVariableValue() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 50.0));
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
customProcessor.process(processedRow);
|
||||
|
||||
// As all these rows have no dependent variable value, they're not for training and should be unaffected
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 100.0));
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
String value = fieldIndex == dependentVariableIndex ? "" : randomAlphaOfLength(10);
|
||||
row[fieldIndex] = value;
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
customProcessor.process(processedRow);
|
||||
|
||||
// We should pick them all as training percent is 100
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
}
|
||||
|
||||
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
|
||||
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
|
||||
double trainingFraction = trainingPercent / 100;
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, trainingPercent));
|
||||
|
||||
int runCount = 20;
|
||||
int rowsCount = 1000;
|
||||
int[] trainingRowsPerRun = new int[runCount];
|
||||
for (int testIndex = 0; testIndex < runCount; testIndex++) {
|
||||
int trainingRows = 0;
|
||||
for (int i = 0; i < rowsCount; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
row[fieldIndex] = randomAlphaOfLength(10);
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
customProcessor.process(processedRow);
|
||||
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
if (fieldIndex != dependentVariableIndex) {
|
||||
assertThat(processedRow[fieldIndex], equalTo(row[fieldIndex]));
|
||||
}
|
||||
}
|
||||
if (processedRow[dependentVariableIndex].length() > 0) {
|
||||
assertThat(processedRow[dependentVariableIndex], equalTo(row[dependentVariableIndex]));
|
||||
trainingRows++;
|
||||
}
|
||||
}
|
||||
trainingRowsPerRun[testIndex] = trainingRows;
|
||||
}
|
||||
|
||||
double meanTrainingRows = IntStream.of(trainingRowsPerRun).average().getAsDouble();
|
||||
|
||||
// Now we need to calculate sensible bounds to assert against.
|
||||
// We'll use 5 variances which should mean the test only fails once in 7M
|
||||
// And, because we're doing multiple runs, we'll divide the variance with the number of runs to narrow the bounds
|
||||
double expectedTrainingRows = trainingFraction * rowsCount;
|
||||
double variance = rowsCount * (Math.pow(1 - trainingFraction, 2) * trainingFraction
|
||||
+ Math.pow(trainingFraction, 2) * (1 - trainingFraction));
|
||||
double lowerBound = expectedTrainingRows - 5 * Math.sqrt(variance / runCount);
|
||||
double upperBound = expectedTrainingRows + 5 * Math.sqrt(variance / runCount);
|
||||
|
||||
assertThat("Mean training rows [" + meanTrainingRows + "] was not within expected bounds of [" + lowerBound + ", "
|
||||
+ upperBound + "] given training fraction was [" + trainingFraction + "]",
|
||||
meanTrainingRows, is(both(greaterThan(lowerBound)).and(lessThan(upperBound))));
|
||||
}
|
||||
|
||||
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
|
||||
CustomProcessor customProcessor = new RegressionCustomProcessor(fields, regression(dependentVariable, 1.0));
|
||||
|
||||
// We have some non-training rows and then a training row to check
|
||||
// we maintain the first training row and not just the first row
|
||||
for (int i = 0; i < 10; i++) {
|
||||
String[] row = new String[fields.size()];
|
||||
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
|
||||
if (i < 9 && fieldIndex == dependentVariableIndex) {
|
||||
row[fieldIndex] = "";
|
||||
} else {
|
||||
row[fieldIndex] = randomAlphaOfLength(10);
|
||||
}
|
||||
}
|
||||
|
||||
String[] processedRow = Arrays.copyOf(row, row.length);
|
||||
customProcessor.process(processedRow);
|
||||
|
||||
assertThat(Arrays.equals(processedRow, row), is(true));
|
||||
}
|
||||
}
|
||||
|
||||
private static Regression regression(String dependentVariable, double trainingPercent) {
|
||||
return new Regression(dependentVariable, null, null, null, null, null, null, trainingPercent);
|
||||
}
|
||||
}
|
@ -1142,6 +1142,52 @@ setup:
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given training_percent is less than one":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-training-percent-is-less-than-one"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 0.999
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given training_percent is greater than hundred":
|
||||
|
||||
- do:
|
||||
catch: /\[training_percent\] must be a double in \[1, 100\]/
|
||||
ml.put_data_frame_analytics:
|
||||
id: "regression-training-percent-is-greater-than-hundred"
|
||||
body: >
|
||||
{
|
||||
"source": {
|
||||
"index": "index-source"
|
||||
},
|
||||
"dest": {
|
||||
"index": "index-dest"
|
||||
},
|
||||
"analysis": {
|
||||
"regression": {
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 100.1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
---
|
||||
"Test put regression given valid":
|
||||
|
||||
@ -1163,7 +1209,8 @@ setup:
|
||||
"gamma": 0.42,
|
||||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1177,7 +1224,8 @@ setup:
|
||||
"gamma": 0.42,
|
||||
"eta": 0.5,
|
||||
"maximum_number_trees": 400,
|
||||
"feature_bag_fraction": 0.3
|
||||
"feature_bag_fraction": 0.3,
|
||||
"training_percent": 60.3
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
@ -1210,7 +1258,8 @@ setup:
|
||||
- match: { dest.index: "index-dest" }
|
||||
- match: { analysis: {
|
||||
"regression":{
|
||||
"dependent_variable": "foo"
|
||||
"dependent_variable": "foo",
|
||||
"training_percent": 100.0
|
||||
}
|
||||
}}
|
||||
- is_true: create_time
|
||||
|
Loading…
x
Reference in New Issue
Block a user