parent
138d16ab9e
commit
d9835f7fb4
|
@ -81,7 +81,11 @@ public class RSquared implements RegressionMetric {
|
|||
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
|
||||
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
||||
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
||||
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
||||
final boolean validResult = residualSumOfSquares == null
|
||||
|| extendedStats == null
|
||||
|| extendedStats.getCount() == 0
|
||||
|| extendedStats.getVariance() == 0;
|
||||
result = validResult ?
|
||||
new Result(0.0) :
|
||||
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
||||
}
|
||||
|
|
|
@ -74,6 +74,21 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
|||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluateWithSingleCountZeroVariance() {
|
||||
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||
createSingleMetricAgg("residual_sum_of_squares", 1),
|
||||
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
|
||||
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
));
|
||||
|
||||
RSquared rSquared = new RSquared();
|
||||
rSquared.process(aggs);
|
||||
|
||||
EvaluationMetricResult result = rSquared.getResult().get();
|
||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||
}
|
||||
|
||||
public void testEvaluate_GivenMissingAggs() {
|
||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||
|
|
Loading…
Reference in New Issue