diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 429dbb2d503..07cfa15b34c 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -126,8 +126,8 @@ import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.PhaseProgress; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.classification.Classification; -import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.classification.MulticlassConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; @@ -1297,6 +1297,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setIndex("put-test-dest-index") .build()) .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") + .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .build()) .setDescription("this is a regression") @@ -1331,6 +1332,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .setIndex("put-test-dest-index") .build()) .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") + .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setNumTopClasses(1) .build()) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index edb92f4ce00..571a619ac13 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -92,7 +92,7 @@ public class Classification implements DataFrameAnalysis { } this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); - this.predictionFieldName = predictionFieldName; + this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; } @@ -113,6 +113,10 @@ public class Classification implements DataFrameAnalysis { return dependentVariable; } + public String getPredictionFieldName() { + return predictionFieldName; + } + public int getNumTopClasses() { return numTopClasses; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 5412156a1b4..6fa163dd65c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -70,7 +70,7 @@ public class Regression implements DataFrameAnalysis { } this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); - this.predictionFieldName = predictionFieldName; + this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; } @@ -89,6 +89,10 @@ public class Regression implements DataFrameAnalysis { return dependentVariable; } + public String getPredictionFieldName() { + return predictionFieldName; + } + public double getTrainingPercent() { return trainingPercent; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 59df68e7944..8306d08af79 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -73,6 +73,14 @@ public class ClassificationTests extends AbstractSerializingTestCase { + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + @Override protected Regression doParseInstance(XContentParser parser) throws IOException { return Regression.fromXContent(parser, false); @@ -42,32 +44,45 @@ public class RegressionTests extends AbstractSerializingTestCase { return Regression::new; } - public void testConstructor_GivenTrainingPercentIsNull() { - Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", null); - assertThat(regression.getTrainingPercent(), equalTo(100.0)); - } - - public void testConstructor_GivenTrainingPercentIsBoundary() { - Regression regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 1.0); - assertThat(regression.getTrainingPercent(), equalTo(1.0)); - regression = new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0); - assertThat(regression.getTrainingPercent(), equalTo(100.0)); - } - public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 0.999)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 100.0001)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } + public void testGetPredictionFieldName() { + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + assertThat(regression.getPredictionFieldName(), equalTo("result")); + + regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0); + assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction")); + } + + public void testGetTrainingPercent() { + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + assertThat(regression.getTrainingPercent(), equalTo(50.0)); + + // Boundary condition: training_percent == 1.0 + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0); + assertThat(regression.getTrainingPercent(), equalTo(1.0)); + + // Boundary condition: training_percent == 100.0 + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0); + assertThat(regression.getTrainingPercent(), equalTo(100.0)); + + // training_percent == null, default applied + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null); + assertThat(regression.getTrainingPercent(), equalTo(100.0)); + } + public void testFieldCardinalityLimitsIsNonNull() { assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index b8bea46422b..6e1828efcd4 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1470,6 +1470,7 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, + "prediction_field_name": "foo_prediction", "training_percent": 60.3 } }} @@ -1809,6 +1810,7 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, + "prediction_field_name": "foo_prediction", "training_percent": 60.3, "num_top_classes": 2 } @@ -1844,6 +1846,7 @@ setup: - match: { analysis: { "regression":{ "dependent_variable": "foo", + "prediction_field_name": "foo_prediction", "training_percent": 100.0 } }}