From 2c7ff812da49acb5e0cacf60588d523a07dc9dd6 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 16 Jul 2019 11:11:31 -0500 Subject: [PATCH] [ML] Add r_squared eval metric to regression (#44248) (#44378) * [ML] Add r_squared eval metric to regression * fixing tests and binarysoftclassification class * Update RSquared.java * Update x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java Co-Authored-By: David Kyle * removing unnecessary debug test --- .../MlEvaluationNamedXContentProvider.java | 5 + .../evaluation/regression/RSquaredMetric.java | 131 +++++++++++++++ .../evaluation/regression/Regression.java | 8 +- .../BinarySoftClassification.java | 21 ++- .../client/MachineLearningIT.java | 10 +- .../client/RestHighLevelClientTests.java | 21 ++- ...usionMatrixMetricConfusionMatrixTests.java | 2 +- .../regression/RSquaredMetricResultTests.java | 53 ++++++ .../regression/RSquaredMetricTests.java | 49 ++++++ .../regression/RegressionTests.java | 14 +- .../BinarySoftClassificationTests.java | 85 ++++++++++ .../MlEvaluationNamedXContentProvider.java | 8 + .../evaluation/regression/RSquared.java | 152 ++++++++++++++++++ .../evaluation/regression/Regression.java | 3 +- .../evaluation/regression/RSquaredTests.java | 116 +++++++++++++ .../regression/RegressionTests.java | 17 +- .../test/ml/evaluate_data_frame.yml | 19 +++ 17 files changed, 694 insertions(+), 20 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricTests.java create mode 100644 
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index b6f07fd4949..a28c498b1d5 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.ml.dataframe.evaluation; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.common.ParseField; @@ -49,6 +50,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric::fromXContent), // Evaluation metrics results new NamedXContentRegistry.Entry( 
EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent), @@ -56,6 +59,8 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent), + new NamedXContentRegistry.Entry( + EvaluationMetric.Result.class, new ParseField(RSquaredMetric.NAME), RSquaredMetric.Result::fromXContent), new NamedXContentRegistry.Entry( EvaluationMetric.Result.class, new ParseField(MeanSquaredErrorMetric.NAME), MeanSquaredErrorMetric.Result::fromXContent), new NamedXContentRegistry.Entry( diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java new file mode 100644 index 00000000000..968489a3038 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetric.java @@ -0,0 +1,131 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +/** + * Calculates R-Squared between two known numerical fields. + * + * equation: R-Squared = 1 - SSres/SStot + * such that, + * SSres = Σ(y - y´)^2 + * SStot = Σ(y - y_mean)^2 + */ +public class RSquaredMetric implements EvaluationMetric { + + public static final String NAME = "r_squared"; + + private static final ObjectParser PARSER = + new ObjectParser<>("r_squared", true, RSquaredMetric::new); + + public static RSquaredMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public RSquaredMetric() { + + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + // create static hash code from name as there are currently no unique fields per class instance + return Objects.hashCode(NAME); + } + + @Override + public String getName() { + return NAME; + } + + public static class Result implements EvaluationMetric.Result { + + public static final ParseField VALUE = new ParseField("value"); + private final double value; + + public 
static Result fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("r_squared_result", true, args -> new Result((double) args[0])); + + static { + PARSER.declareDouble(constructorArg(), VALUE); + } + + public Result(double value) { + this.value = value; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VALUE.getPreferredName(), value); + builder.endObject(); + return builder; + } + + public double getValue() { + return value; + } + + @Override + public String getMetricName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(that.value, this.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java index 13b14f6e0b0..79b9ab6eb1d 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/Regression.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -84,8 +85,11 @@ public class Regression implements Evaluation { } public Regression(String actualField, String predictedField, @Nullable List metrics) { - this.actualField = actualField; - this.predictedField = predictedField; + 
this.actualField = Objects.requireNonNull(actualField); + this.predictedField = Objects.requireNonNull(predictedField); + if (metrics != null) { + metrics.sort(Comparator.comparing(EvaluationMetric::getName)); + } this.metrics = metrics; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 6d5fa04da38..cb531c6ab04 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -52,6 +53,7 @@ public class BinarySoftClassification implements Evaluation { public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, + true, args -> new BinarySoftClassification((String) args[0], (String) args[1], (List) args[2])); static { @@ -80,6 +82,10 @@ public class BinarySoftClassification implements Evaluation { */ private final List metrics; + public BinarySoftClassification(String actualField, String predictedField) { + this(actualField, predictedField, (List)null); + } + public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric... 
metric) { this(actualField, predictedProbabilityField, Arrays.asList(metric)); } @@ -88,7 +94,10 @@ public class BinarySoftClassification implements Evaluation { @Nullable List metrics) { this.actualField = Objects.requireNonNull(actualField); this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField); - this.metrics = Objects.requireNonNull(metrics); + if (metrics != null) { + metrics.sort(Comparator.comparing(EvaluationMetric::getName)); + } + this.metrics = metrics; } @Override @@ -102,11 +111,13 @@ public class BinarySoftClassification implements Evaluation { builder.field(ACTUAL_FIELD.getPreferredName(), actualField); builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); - builder.startObject(METRICS.getPreferredName()); - for (EvaluationMetric metric : metrics) { - builder.field(metric.getName(), metric); + if (metrics != null) { + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); } - builder.endObject(); builder.endObject(); return builder; diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index d99d9ecd29d..5a92602d004 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -124,6 +124,7 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import 
org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -1597,16 +1598,21 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { .add(docForRegression(regressionIndex, 0.5, 0.9)); // #9 highLevelClient().bulk(regressionBulk, RequestOptions.DEFAULT); - evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, new Regression(actualRegression, probabilityRegression)); + evaluateDataFrameRequest = new EvaluateDataFrameRequest(regressionIndex, + new Regression(actualRegression, probabilityRegression, new MeanSquaredErrorMetric(), new RSquaredMetric())); evaluateDataFrameResponse = execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(Regression.NAME)); - assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(1)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(2)); MeanSquaredErrorMetric.Result mseResult = evaluateDataFrameResponse.getMetricByName(MeanSquaredErrorMetric.NAME); assertThat(mseResult.getMetricName(), equalTo(MeanSquaredErrorMetric.NAME)); assertThat(mseResult.getError(), closeTo(0.061000000, 1e-9)); + + RSquaredMetric.Result rSquaredResult = evaluateDataFrameResponse.getMetricByName(RSquaredMetric.NAME); + assertThat(rSquaredResult.getMetricName(), equalTo(RSquaredMetric.NAME)); + assertThat(rSquaredResult.getValue(), closeTo(-5.1000000000000005, 1e-9)); } private static XContentBuilder defaultMappingForTest() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 
77dc9ee53fd..98c5cf87030 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -61,6 +61,7 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.evaluation.regression.MeanSquaredErrorMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.regression.RSquaredMetric; import org.elasticsearch.client.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; @@ -676,7 +677,7 @@ public class RestHighLevelClientTests extends ESTestCase { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(34, namedXContents.size()); + assertEquals(36, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -716,12 +717,22 @@ public class RestHighLevelClientTests extends ESTestCase { assertTrue(names.contains(TimeSyncConfig.NAME)); assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); assertThat(names, hasItems(BinarySoftClassification.NAME, Regression.NAME)); - assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); assertThat(names, - hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, 
MeanSquaredErrorMetric.NAME)); - assertEquals(Integer.valueOf(5), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + hasItems(AucRocMetric.NAME, + PrecisionMetric.NAME, + RecallMetric.NAME, + ConfusionMatrixMetric.NAME, + MeanSquaredErrorMetric.NAME, + RSquaredMetric.NAME)); + assertEquals(Integer.valueOf(6), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); assertThat(names, - hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, MeanSquaredErrorMetric.NAME)); + hasItems(AucRocMetric.NAME, + PrecisionMetric.NAME, + RecallMetric.NAME, + ConfusionMatrixMetric.NAME, + MeanSquaredErrorMetric.NAME, + RSquaredMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java index 28eb221b318..b54bcd53fc4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java @@ -26,7 +26,7 @@ import java.io.IOException; public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase { - static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() { + public static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() { return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt()); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricResultTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricResultTests.java new file mode 100644 index 00000000000..3d18418a752 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricResultTests.java @@ -0,0 +1,53 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RSquaredMetricResultTests extends AbstractXContentTestCase { + + public static RSquaredMetric.Result randomResult() { + return new RSquaredMetric.Result(randomDouble()); + } + + @Override + protected RSquaredMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected RSquaredMetric.Result doParseInstance(XContentParser parser) throws IOException { + return RSquaredMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricTests.java new file mode 100644 index 00000000000..ab8b9e0f7af --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RSquaredMetricTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.regression; + +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class RSquaredMetricTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected RSquaredMetric createTestInstance() { + return new RSquaredMetric(); + } + + @Override + protected RSquaredMetric doParseInstance(XContentParser parser) throws IOException { + return RSquaredMetric.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java index f5b3db9cec8..89e4823b93e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/regression/RegressionTests.java @@ -18,13 +18,15 @@ */ package 
org.elasticsearch.client.ml.dataframe.evaluation.regression; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -import java.util.Collections; +import java.util.ArrayList; +import java.util.List; import java.util.function.Predicate; public class RegressionTests extends AbstractXContentTestCase { @@ -36,9 +38,16 @@ public class RegressionTests extends AbstractXContentTestCase { @Override protected Regression createTestInstance() { + List metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(new MeanSquaredErrorMetric()); + } + if (randomBoolean()) { + metrics.add(new RSquaredMetric()); + } return randomBoolean() ? new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10)) : - new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), Collections.singletonList(new MeanSquaredErrorMetric())); + new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? 
null : metrics); } @Override @@ -56,4 +65,5 @@ public class RegressionTests extends AbstractXContentTestCase { // allow unknown fields in the root of the object only return field -> !field.isEmpty(); } + } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java new file mode 100644 index 00000000000..2fb8a21e3a1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -0,0 +1,85 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.function.Predicate; + +public class BinarySoftClassificationTests extends AbstractXContentTestCase { + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected BinarySoftClassification createTestInstance() { + List metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(new AucRocMetric(randomBoolean())); + } + if (randomBoolean()) { + metrics.add(new PrecisionMetric(Arrays.asList(randomArray(1, + 4, + Double[]::new, + BinarySoftClassificationTests::randomDouble)))); + } + if (randomBoolean()) { + metrics.add(new RecallMetric(Arrays.asList(randomArray(1, + 4, + Double[]::new, + BinarySoftClassificationTests::randomDouble)))); + } + if (randomBoolean()) { + metrics.add(new ConfusionMatrixMetric(Arrays.asList(randomArray(1, + 4, + Double[]::new, + BinarySoftClassificationTests::randomDouble)))); + } + return randomBoolean() ? + new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10)) : + new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics.isEmpty() ? 
null : metrics); + } + + @Override + protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException { + return BinarySoftClassification.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index f713aa0033d..a2aa8e74918 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.MeanSquaredError; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RSquared; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.Regression; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; @@ -42,6 +43,7 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider // Regression metrics namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME, MeanSquaredError::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(RegressionMetric.class, RSquared.NAME, 
RSquared::fromXContent)); return namedXContent; } @@ -66,6 +68,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(RegressionMetric.class, + RSquared.NAME.getPreferredName(), + RSquared::new)); // Evaluation Metrics Results namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), @@ -77,6 +82,9 @@ public class MlEvaluationNamedXContentProvider implements NamedXContentProvider namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, MeanSquaredError.NAME.getPreferredName(), MeanSquaredError.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, + RSquared.NAME.getPreferredName(), + RSquared.Result::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java new file mode 100644 index 00000000000..871f166733f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilders;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
+import org.elasticsearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.text.MessageFormat;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Objects;
+
+/**
+ * Calculates R-Squared between two known numerical fields.
+ *
+ * equation: R-Squared = 1 - SSres/SStot
+ * such that,
+ * SSres = Σ(y - y´)^2, The residual sum of squares
+ * SStot = Σ(y - y_mean)^2, The total sum of squares
+ *
+ * The metric is undefined when SStot is zero (no documents, or every actual value is identical);
+ * in that case {@link #evaluate(Aggregations)} returns {@code null} instead of a non-finite number.
+ */
+public class RSquared implements RegressionMetric {
+
+    public static final ParseField NAME = new ParseField("r_squared");
+
+    // MessageFormat pattern: a doubled single quote is an escaped literal quote, {0}/{1} are the field names.
+    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";
+    private static final String SS_RES = "residual_sum_of_squares";
+    // Name of the extended stats aggregation computed over the actual field.
+    private static final String ACTUAL_STATS = ExtendedStatsAggregationBuilder.NAME + "_actual";
+
+    /**
+     * Fills {@link #PAINLESS_TEMPLATE} with the given arguments.
+     *
+     * @param args the actual field name followed by the predicted field name
+     * @return the script source computing the squared difference of the two fields
+     */
+    private static String buildScript(Object... args) {
+        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(args);
+    }
+
+    private static final ObjectParser<RSquared, Void> PARSER =
+        new ObjectParser<>("r_squared", true, RSquared::new);
+
+    /**
+     * Parses an {@code r_squared} metric object from the given parser.
+     */
+    public static RSquared fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * Stream constructor. The metric has no state, so nothing is read from the stream.
+     */
+    public RSquared(StreamInput in) {
+
+    }
+
+    public RSquared() {
+
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Builds the aggregations this metric needs: a scripted sum of squared residuals and
+     * extended stats over the actual field.
+     *
+     * @param actualField    name of the field holding the actual values
+     * @param predictedField name of the field holding the predicted values
+     */
+    @Override
+    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
+        return Arrays.asList(
+            AggregationBuilders.sum(SS_RES).script(new Script(buildScript(actualField, predictedField))),
+            AggregationBuilders.extendedStats(ACTUAL_STATS).field(actualField));
+    }
+
+    /**
+     * Computes R-Squared from the aggregations requested by {@link #aggs(String, String)}.
+     *
+     * @param aggs the aggregation results
+     * @return the result, or {@code null} when either aggregation is missing, no documents were
+     *         counted, or the total sum of squares is zero (the ratio would not be a finite number)
+     */
+    @Override
+    public EvaluationMetricResult evaluate(Aggregations aggs) {
+        NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
+        ExtendedStats extendedStats = aggs.get(ACTUAL_STATS);
+        if (residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0) {
+            return null;
+        }
+        // extendedStats.getVariance() is the statistical sumOfSquares divided by count
+        double totalSumOfSquares = extendedStats.getVariance() * extendedStats.getCount();
+        if (totalSumOfSquares == 0.0) {
+            // Dividing by zero would yield NaN or an infinity; report "no result" instead.
+            return null;
+        }
+        return new Result(1 - (residualSumOfSquares.value() / totalSumOfSquares));
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    /**
+     * Writes nothing: the metric has no state to serialize.
+     */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        return true;
+    }
+
+    @Override
+    public int hashCode() {
+        // create static hash code from name as there are currently no unique fields per class instance
+        return Objects.hashCode(NAME.getPreferredName());
+    }
+
+    /**
+     * The computed R-Squared value, rendered as an object with a single {@code value} field.
+     */
+    public static class Result implements EvaluationMetricResult {
+
+        private static final String VALUE = "value";
+        private final double value;
+
+        public Result(double value) {
+            this.value = value;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.value = in.readDouble();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(value);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(VALUE, value);
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result other = (Result) o;
+            return Double.compare(value, other.value) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Double.hashCode(value);
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java
index 455f44ae3c1..e3869dce2ee 100644
---
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -94,8 +94,9 @@ public class Regression implements Evaluation { } private static List defaultMetrics() { - List defaultMetrics = new ArrayList<>(1); + List defaultMetrics = new ArrayList<>(2); defaultMetrics.add(new MeanSquaredError()); + defaultMetrics.add(new RSquared()); return defaultMetrics; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java new file mode 100644 index 00000000000..97ec16494e0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquaredTests.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.metrics.ExtendedStats;
+import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
+import org.elasticsearch.test.AbstractSerializingTestCase;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Serialization round-trip tests for {@link RSquared} plus checks of
+ * {@link RSquared#evaluate(Aggregations)} against mocked aggregation results.
+ */
+public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
+
+    @Override
+    protected RSquared doParseInstance(XContentParser parser) throws IOException {
+        return RSquared.fromXContent(parser);
+    }
+
+    @Override
+    protected RSquared createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected Writeable.Reader<RSquared> instanceReader() {
+        return RSquared::new;
+    }
+
+    /**
+     * Returns a fresh metric; {@link RSquared} has no fields to randomize.
+     */
+    public static RSquared createRandom() {
+        return new RSquared();
+    }
+
+    /**
+     * The value is computed from the two aggregations the metric asks for; unrelated ones are ignored.
+     */
+    public void testEvaluate() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            createSingleMetricAgg("residual_sum_of_squares", 10_111),
+            createExtendedStatsAgg("extended_stats_actual", 155.23, 1000),
+            createExtendedStatsAgg("some_other_extended_stats", 99.1, 10_000),
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        RSquared rSquared = new RSquared();
+        EvaluationMetricResult result = rSquared.evaluate(aggs);
+
+        String expected = "{\"value\":0.9348643947690524}";
+        assertThat(Strings.toString(result), equalTo(expected));
+    }
+
+    /**
+     * A document count of zero yields no result.
+     */
+    public void testEvaluateWithZeroCount() {
+        Aggregations aggs = new Aggregations(Arrays.asList(
+            createSingleMetricAgg("residual_sum_of_squares", 0),
+            createExtendedStatsAgg("extended_stats_actual", 0.0, 0),
+            createExtendedStatsAgg("some_other_extended_stats", 99.1, 10_000),
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        RSquared rSquared = new RSquared();
+        EvaluationMetricResult result = rSquared.evaluate(aggs);
+        assertThat(result, is(nullValue()));
+    }
+
+    /**
+     * No result is produced when both, or either one, of the required aggregations is absent.
+     */
+    public void testEvaluate_GivenMissingAggs() {
+        // neither required aggregation is present
+        Aggregations aggs = new Aggregations(Collections.singletonList(
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
+        ));
+
+        RSquared rSquared = new RSquared();
+        EvaluationMetricResult result = rSquared.evaluate(aggs);
+        assertThat(result, is(nullValue()));
+
+        // only the residual sum of squares is present
+        aggs = new Aggregations(Arrays.asList(
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
+            createSingleMetricAgg("residual_sum_of_squares", 0.2377)
+        ));
+
+        result = rSquared.evaluate(aggs);
+        assertThat(result, is(nullValue()));
+
+        // only the extended stats of the actual field are present
+        aggs = new Aggregations(Arrays.asList(
+            createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
+            createExtendedStatsAgg("extended_stats_actual", 100, 50)
+        ));
+
+        result = rSquared.evaluate(aggs);
+        assertThat(result, is(nullValue()));
+    }
+
+    /**
+     * Mocks a single-value metric aggregation with the given name and value.
+     */
+    private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {
+        NumericMetricsAggregation.SingleValue agg = mock(NumericMetricsAggregation.SingleValue.class);
+        when(agg.getName()).thenReturn(name);
+        when(agg.value()).thenReturn(value);
+        return agg;
+    }
+
+    /**
+     * Mocks an extended stats aggregation exposing only its name, variance and count.
+     */
+    private static ExtendedStats createExtendedStatsAgg(String name, double variance, long count) {
+        ExtendedStats agg = mock(ExtendedStats.class);
+        when(agg.getName()).thenReturn(name);
+        when(agg.getVariance()).thenReturn(variance);
+        when(agg.getCount()).thenReturn(count);
+        return agg;
+    }
+}
diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java index 33ce6e56ff5..d0bcc1a11f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RegressionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -32,8 +33,20 @@ public class RegressionTests extends AbstractSerializingTestCase { } public static Regression createRandom() { - List metrics = Collections.singletonList(MeanSquaredErrorTests.createRandom()); - return new Regression(randomAlphaOfLength(10), randomAlphaOfLength(10), randomBoolean() ? null : metrics); + List metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(MeanSquaredErrorTests.createRandom()); + } + if (randomBoolean()) { + metrics.add(RSquaredTests.createRandom()); + } + return new Regression(randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomBoolean() ? + null : + metrics.isEmpty() ? 
+ null : + metrics); } @Override diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index d0ed46b0f04..46d903977eb 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -567,6 +567,24 @@ setup: } - match: { regression.mean_squared_error.error: 28.67749840974834 } + - is_false: regression.r_squared.value +--- +"Test regression r_squared": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "regression_field_pred", + "metrics": { "r_squared": {} } + } + } + } + - match: { regression.r_squared.value: 0.8551031778603486 } + - is_false: regression.mean_squared_error --- "Test regression with null metrics": - do: @@ -583,3 +601,4 @@ setup: } - match: { regression.mean_squared_error.error: 28.67749840974834 } + - match: { regression.r_squared.value: 0.8551031778603486 }