From ba7b6776181cf656b72a83c8dd578f4ec9dc5834 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 20 Aug 2019 17:37:04 -0500 Subject: [PATCH] [ML] better handle empty results when evaluating regression (#45745) (#45759) * [ML] better handle empty results when evaluating regression * adding new failure test to ml_security black list * fixing equality check for regression results --- .../regression/MeanSquaredError.java | 15 ++++++++- .../evaluation/regression/RSquared.java | 15 ++++++++- .../evaluation/regression/Regression.java | 6 ++++ .../evaluation/softclassification/Recall.java | 2 +- .../regression/MeanSquaredErrorTests.java | 4 +-- .../evaluation/regression/RSquaredTests.java | 11 +++---- .../ml/qa/ml-with-security/build.gradle | 2 ++ .../test/ml/evaluate_data_frame.yml | 31 +++++++++++++++++++ 8 files changed, 74 insertions(+), 12 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java index 8dd922b6ac2..e48cb46b5c0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredError.java @@ -69,7 +69,7 @@ public class MeanSquaredError implements RegressionMetric { @Override public EvaluationMetricResult evaluate(Aggregations aggs) { NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME); - return value == null ? null : new Result(value.value()); + return value == null ? new Result(0.0) : new Result(value.value()); } @Override @@ -137,5 +137,18 @@ public class MeanSquaredError implements RegressionMetric { builder.endObject(); return builder; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result other = (Result)o; + return error == other.error; + } + + @Override + public int hashCode() { + return Objects.hashCode(error); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java index 871f166733f..a5530656183 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/RSquared.java @@ -79,7 +79,7 @@ public class RSquared implements RegressionMetric { ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual"); // extendedStats.getVariance() is the statistical sumOfSquares divided by count return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ? - null : + new Result(0.0) : new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount()))); } @@ -148,5 +148,18 @@ public class RSquared implements RegressionMetric { builder.endObject(); return builder; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result other = (Result)o; + return value == other.value; + } + + @Override + public int hashCode() { + return Objects.hashCode(value); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java index e3869dce2ee..610c065fd81 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/Regression.java @@ -121,6 +121,12 @@ public class Regression implements Evaluation { @Override public void evaluate(SearchResponse searchResponse, ActionListener> listener) { List results = new ArrayList<>(metrics.size()); + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", + actualField, + predictedField)); + return; + } for (RegressionMetric metric : metrics) { results.add(metric.evaluate(searchResponse.getAggregations())); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java index 5c4ab57241d..f7103aceeda 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -81,7 +81,7 @@ public class Recall extends AbstractConfusionMatrixMetric { for (int i = 0; i < recalls.length; i++) { double threshold = thresholds[i]; Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); - Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN)); + Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN)); long tp = tpAgg.getDocCount(); long fn = fnAgg.getDocCount(); recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java index 43513514747..a22c499220c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/regression/MeanSquaredErrorTests.java @@ -17,9 +17,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -64,7 +62,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase { RSquared rSquared = new RSquared(); EvaluationMetricResult result = rSquared.evaluate(aggs); - assertThat(result, is(nullValue())); + assertThat(result, equalTo(new RSquared.Result(0.0))); } public void testEvaluate_GivenMissingAggs() { + EvaluationMetricResult zeroResult = new RSquared.Result(0.0); Aggregations aggs = new Aggregations(Collections.singletonList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377) )); RSquared rSquared = new RSquared(); EvaluationMetricResult result = rSquared.evaluate(aggs); - assertThat(result, is(nullValue())); + assertThat(result, equalTo(zeroResult)); aggs = new Aggregations(Arrays.asList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377), @@ -88,7 +87,7 @@ public class RSquaredTests extends AbstractSerializingTestCase { )); result = rSquared.evaluate(aggs); - assertThat(result, is(nullValue())); + assertThat(result, equalTo(zeroResult)); aggs = new Aggregations(Arrays.asList( createSingleMetricAgg("some_other_single_metric_agg", 0.2377), @@ -96,7 +95,7 @@ public class RSquaredTests extends AbstractSerializingTestCase { )); result = rSquared.evaluate(aggs); - assertThat(result, is(nullValue())); + assertThat(result, equalTo(zeroResult)); } private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) { diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 8342a0f9430..3f7bebe8514 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -86,6 +86,8 @@ integTest.runner { 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', 'ml/evaluate_data_frame/Test regression given evaluation with empty metrics', + 'ml/evaluate_data_frame/Test regression given missing actual_field', + 'ml/evaluate_data_frame/Test regression given missing predicted_field', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 46d903977eb..a4d3c1f1979 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -602,3 +602,34 @@ setup: - match: { regression.mean_squared_error.error: 28.67749840974834 } - match: { regression.r_squared.value: 0.8551031778603486 } +--- +"Test regression given missing actual_field": + - do: + catch: /No documents found containing both \[missing, regression_field_pred\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "missing", + "predicted_field": "regression_field_pred" + } + } + } + +--- +"Test regression given missing predicted_field": + - do: + catch: /No documents found containing both \[regression_field_act, missing\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "regression": { + "actual_field": "regression_field_act", + "predicted_field": "missing" + } + } + }