parent
138d16ab9e
commit
d9835f7fb4
|
@ -81,7 +81,11 @@ public class RSquared implements RegressionMetric {
|
||||||
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
|
NumericMetricsAggregation.SingleValue residualSumOfSquares = aggs.get(SS_RES);
|
||||||
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
ExtendedStats extendedStats = aggs.get(ExtendedStatsAggregationBuilder.NAME + "_actual");
|
||||||
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
// extendedStats.getVariance() is the statistical sumOfSquares divided by count
|
||||||
result = residualSumOfSquares == null || extendedStats == null || extendedStats.getCount() == 0 ?
|
final boolean validResult = residualSumOfSquares == null
|
||||||
|
|| extendedStats == null
|
||||||
|
|| extendedStats.getCount() == 0
|
||||||
|
|| extendedStats.getVariance() == 0;
|
||||||
|
result = validResult ?
|
||||||
new Result(0.0) :
|
new Result(0.0) :
|
||||||
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
new Result(1 - (residualSumOfSquares.value() / (extendedStats.getVariance() * extendedStats.getCount())));
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,6 +74,21 @@ public class RSquaredTests extends AbstractSerializingTestCase<RSquared> {
|
||||||
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testEvaluateWithSingleCountZeroVariance() {
|
||||||
|
Aggregations aggs = new Aggregations(Arrays.asList(
|
||||||
|
createSingleMetricAgg("residual_sum_of_squares", 1),
|
||||||
|
createExtendedStatsAgg("extended_stats_actual", 0.0, 1),
|
||||||
|
createExtendedStatsAgg("some_other_extended_stats",99.1, 10_000),
|
||||||
|
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||||
|
));
|
||||||
|
|
||||||
|
RSquared rSquared = new RSquared();
|
||||||
|
rSquared.process(aggs);
|
||||||
|
|
||||||
|
EvaluationMetricResult result = rSquared.getResult().get();
|
||||||
|
assertThat(result, equalTo(new RSquared.Result(0.0)));
|
||||||
|
}
|
||||||
|
|
||||||
public void testEvaluate_GivenMissingAggs() {
|
public void testEvaluate_GivenMissingAggs() {
|
||||||
Aggregations aggs = new Aggregations(Collections.singletonList(
|
Aggregations aggs = new Aggregations(Collections.singletonList(
|
||||||
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
createSingleMetricAgg("some_other_single_metric_agg", 0.2377)
|
||||||
|
|
Loading…
Reference in New Issue