[ML] better handle empty results when evaluating regression (#45745) (#45759)

* [ML] better handle empty results when evaluating regression

* adding new failure test to ml_security black list

* fixing equality check for regression results
This commit is contained in:
Benjamin Trent 2019-08-20 17:37:04 -05:00 committed by GitHub
parent 686739d456
commit ba7b677618
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 74 additions and 12 deletions

View File

@ -69,7 +69,7 @@ public class MeanSquaredError implements RegressionMetric {
@Override
public EvaluationMetricResult evaluate(Aggregations aggs) {
NumericMetricsAggregation.SingleValue value = aggs.get(AGG_NAME);
return value == null ? null : new Result(value.value());
return value == null ? new Result(0.0) : new Result(value.value());
}
@Override
@ -137,5 +137,18 @@ public class MeanSquaredError implements RegressionMetric {
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return error == other.error;
}
@Override
public int hashCode() {
return Objects.hashCode(error);
}
}
}

View File

@ -79,7 +79,7 @@ public class RSquared implements RegressionMetric {
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
return residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
null :
new Result(0.0) :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
}
@ -148,5 +148,18 @@ public class RSquared implements RegressionMetric {
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Result other = (Result)o;
return value == other.value;
}
@Override
public int hashCode() {
return Objects.hashCode(value);
}
}
}

View File

@ -121,6 +121,12 @@ public class Regression implements Evaluation {
@Override
public void evaluate(SearchResponse searchResponse, ActionListener<List<EvaluationMetricResult>> listener) {
List<EvaluationMetricResult> results = new ArrayList<>(metrics.size());
if (searchResponse.getHits().getTotalHits().value == 0) {
listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields",
actualField,
predictedField));
return;
}
for (RegressionMetric metric : metrics) {
results.add(metric.evaluate(searchResponse.getAggregations()));
}

View File

@ -81,7 +81,7 @@ public class Recall extends AbstractConfusionMatrixMetric {
for (int i = 0; i < recalls.length; i++) {
double threshold = thresholds[i];
Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN));
Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
long tp = tpAgg.getDocCount();
long fn = fnAgg.getDocCount();
recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);

View File

@ -17,9 +17,7 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -64,7 +62,7 @@ public class MeanSquaredErrorTests extends AbstractSerializingTestCase<MeanSquar
MeanSquaredError mse = new MeanSquaredError();
EvaluationMetricResult result = mse.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(new MeanSquaredError.Result(0.0)));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

View File

@ -18,9 +18,7 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -70,17 +68,18 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluate_GivenMissingAggs() {
EvaluationMetricResult zeroResult = new RSquared.Result(0.0);
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
EvaluationMetricResult result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));
aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
@ -88,7 +87,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
));
result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));
aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377),
@ -96,7 +95,7 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
));
result = rSquared.evaluate(aggs);
assertThat(result, is(nullValue()));
assertThat(result, equalTo(zeroResult));
}
private static NumericMetricsAggregation.SingleValue createSingleMetricAgg(String name, double value) {

View File

@ -86,6 +86,8 @@ integTest.runner {
'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds',
'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds',
'ml/evaluate_data_frame/Test regression given evaluation with empty metrics',
'ml/evaluate_data_frame/Test regression given missing actual_field',
'ml/evaluate_data_frame/Test regression given missing predicted_field',
'ml/delete_job_force/Test cannot force delete a non-existent job',
'ml/delete_model_snapshot/Test delete snapshot missing snapshotId',
'ml/delete_model_snapshot/Test delete snapshot missing job_id',

View File

@ -602,3 +602,34 @@ setup:
- match: { regression.mean_squared_error.error: 28.67749840974834 }
- match: { regression.r_squared.value: 0.8551031778603486 }
---
"Test regression given missing actual_field":
- do:
catch: /No documents found containing both \[missing, regression_field_pred\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "missing",
"predicted_field": "regression_field_pred"
}
}
}
---
"Test regression given missing predicted_field":
- do:
catch: /No documents found containing both \[regression_field_act, missing\] fields/
ml.evaluate_data_frame:
body: >
{
"index": "utopia",
"evaluation": {
"regression": {
"actual_field": "regression_field_act",
"predicted_field": "missing"
}
}
}