[ML] Fix r_squared eval when variance is 0 (#49439) (#49445)

This commit is contained in:
Benjamin Trent 2019-11-21 11:22:16 -05:00 committed by GitHub
parent 138d16ab9e
commit d9835f7fb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 1 deletions

View File

@ -81,7 +81,11 @@ public class RSquared implements RegressionMetric {
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
final boolean validResult = residualSumOfSquares == null
|| extendedStats == null
|| extendedStats.getCount() == 0
|| extendedStats.getVariance() == 0;
result = validResult ?
new Result(0.0) :
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
}

View File

@ -74,6 +74,21 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluateWithSingleCountZeroVariance() {
Aggregations aggs = new Aggregations(Arrays.asList(
createSingleMetricAgg("residual_sum_of_squares", 1),
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
));
RSquared rSquared = new RSquared();
rSquared.process(aggs);
EvaluationMetricResult result = rSquared.getResult().get();
assertThat(result, equalTo(new RSquared.Result(0.0)));
}
public void testEvaluate_GivenMissingAggs() {
Aggregations aggs = new Aggregations(Collections.singletonList(
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)