[7.x] Default "prediction_field_name" to (dependent_variable + "_prediction") (#48232) (#48279)

This commit is contained in:
Przemysław Witek 2019-10-21 13:18:08 +02:00 committed by GitHub
parent 69fc715bc3
commit 1a42e37070
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 17 deletions

View File

@ -126,8 +126,8 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection;
import org.elasticsearch.client.ml.dataframe.PhaseProgress;
import org.elasticsearch.client.ml.dataframe.QueryConfig;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
@ -1297,6 +1297,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.build())
.setDescription("this is a regression")
@ -1331,6 +1332,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
.setIndex("put-test-dest-index")
.build())
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
.setPredictionFieldName("my_dependent_variable_prediction")
.setTrainingPercent(80.0)
.setNumTopClasses(1)
.build())

View File

@ -92,7 +92,7 @@ public class Classification implements DataFrameAnalysis {
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
@ -113,6 +113,10 @@ public class Classification implements DataFrameAnalysis {
return dependentVariable;
}
/**
 * Returns the name of the field that holds the prediction.
 * When no name was passed to the constructor, the value is the dependent variable
 * name followed by {@code "_prediction"}.
 *
 * @return the prediction field name
 */
public String getPredictionFieldName() {
    return this.predictionFieldName;
}
/**
 * Returns the configured number of top classes.
 * When the constructor received {@code null}, the value is {@code DEFAULT_NUM_TOP_CLASSES}.
 *
 * @return the number of top classes
 */
public int getNumTopClasses() {
    return this.numTopClasses;
}

View File

@ -70,7 +70,7 @@ public class Regression implements DataFrameAnalysis {
}
this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME);
this.predictionFieldName = predictionFieldName;
this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
}
@ -89,6 +89,10 @@ public class Regression implements DataFrameAnalysis {
return dependentVariable;
}
/**
 * Returns the name of the field that holds the prediction.
 * When no name was passed to the constructor, the value is the dependent variable
 * name followed by {@code "_prediction"}.
 *
 * @return the prediction field name
 */
public String getPredictionFieldName() {
    return this.predictionFieldName;
}
/**
 * Returns the configured training percent.
 * When the constructor received {@code null}, the value is {@code 100.0}.
 *
 * @return the training percent
 */
public double getTrainingPercent() {
    return this.trainingPercent;
}

View File

@ -73,6 +73,14 @@ public class ClassificationTests extends AbstractSerializingTestCase<Classificat
assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]"));
}
/**
 * Checks that an explicitly supplied prediction field name is returned unchanged and that
 * a {@code null} name yields the dependent variable name suffixed with "_prediction".
 */
public void testGetPredictionFieldName() {
    // Explicit name is kept as-is.
    Classification withExplicitName = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0);
    assertThat(withExplicitName.getPredictionFieldName(), equalTo("result"));

    // Missing name is derived from the dependent variable.
    Classification withDerivedName = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0);
    assertThat(withDerivedName.getPredictionFieldName(), equalTo("foo_prediction"));
}
public void testGetNumTopClasses() {
Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0);
assertThat(classification.getNumTopClasses(), equalTo(7));

View File

@ -19,6 +19,8 @@ import static org.hamcrest.Matchers.nullValue;
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
@Override
protected Regression doParseInstance(XContentParser parser) throws IOException {
return Regression.fromXContent(parser, false);
@ -42,32 +44,45 @@ public class RegressionTests extends AbstractSerializingTestCase<Regression> {
return Regression::new;
}
/** Checks that a {@code null} training percent results in 100.0 being reported. */
public void testConstructor_GivenTrainingPercentIsNull() {
    BoostedTreeParams treeParams = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
    Regression withoutTrainingPercent = new Regression("foo", treeParams, "result", null);
    assertThat(withoutTrainingPercent.getTrainingPercent(), equalTo(100.0));
}
/** Checks that both boundary values of the training percent (1.0 and 100.0) are accepted and reported. */
public void testConstructor_GivenTrainingPercentIsBoundary() {
    // Lower boundary.
    BoostedTreeParams lowerParams = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
    Regression atLowerBound = new Regression("foo", lowerParams, "result", 1.0);
    assertThat(atLowerBound.getTrainingPercent(), equalTo(1.0));

    // Upper boundary.
    BoostedTreeParams upperParams = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0);
    Regression atUpperBound = new Regression("foo", upperParams, "result", 100.0);
    assertThat(atUpperBound.getTrainingPercent(), equalTo(100.0));
}
public void testConstructor_GivenTrainingPercentIsLessThanOne() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
public void testConstructor_GivenTrainingPercentIsGreaterThan100() {
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001));
() -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001));
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
}
/**
 * Checks that an explicitly supplied prediction field name is returned unchanged and that
 * a {@code null} name yields the dependent variable name suffixed with "_prediction".
 */
public void testGetPredictionFieldName() {
    // Explicit name is kept as-is.
    Regression withExplicitName = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
    assertThat(withExplicitName.getPredictionFieldName(), equalTo("result"));

    // Missing name is derived from the dependent variable.
    Regression withDerivedName = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0);
    assertThat(withDerivedName.getPredictionFieldName(), equalTo("foo_prediction"));
}
/**
 * Checks the value reported by {@code getTrainingPercent()} for a mid-range value,
 * both boundaries (1.0 and 100.0) and a {@code null} argument.
 */
public void testGetTrainingPercent() {
    // Mid-range value is reported unchanged.
    Regression midRange = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0);
    assertThat(midRange.getTrainingPercent(), equalTo(50.0));

    // Boundary condition: training_percent == 1.0
    Regression lowerBound = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0);
    assertThat(lowerBound.getTrainingPercent(), equalTo(1.0));

    // Boundary condition: training_percent == 100.0
    Regression upperBound = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0);
    assertThat(upperBound.getTrainingPercent(), equalTo(100.0));

    // training_percent == null, default applied
    Regression unspecified = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null);
    assertThat(unspecified.getTrainingPercent(), equalTo(100.0));
}
/** Checks that a test instance never reports {@code null} field cardinality limits. */
public void testFieldCardinalityLimitsIsNonNull() {
    Regression instance = createTestInstance();
    assertThat(instance.getFieldCardinalityLimits(), is(not(nullValue())));
}

View File

@ -1470,6 +1470,7 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3
}
}}
@ -1809,6 +1810,7 @@ setup:
"eta": 0.5,
"maximum_number_trees": 400,
"feature_bag_fraction": 0.3,
"prediction_field_name": "foo_prediction",
"training_percent": 60.3,
"num_top_classes": 2
}
@ -1844,6 +1846,7 @@ setup:
- match: { analysis: {
"regression":{
"dependent_variable": "foo",
"prediction_field_name": "foo_prediction",
"training_percent": 100.0
}
}}